/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.sql2rel;

import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.linq4j.function.Function2;
import org.apache.calcite.plan.Context;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCostImpl;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.Strong;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.BiRel;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.logical.LogicalSnapshot;
import org.apache.calcite.rel.logical.LogicalTableFunctionScan;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.rel.rules.FilterCorrelateRule;
import org.apache.calcite.rel.rules.FilterFlattenCorrelatedConditionRule;
import org.apache.calcite.rel.rules.FilterJoinRule;
import org.apache.calcite.rel.rules.FilterProjectTransposeRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlExplainFormat;
import org.apache.calcite.sql.SqlExplainLevel;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlCountAggFunction;
import org.apache.calcite.sql.fun.SqlSingleValueAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.tools.RuleSet;
import org.apache.calcite.util.Holder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Litmus;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.ReflectUtil;
import org.apache.calcite.util.ReflectiveVisitor;
import org.apache.calcite.util.Util;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.calcite.util.trace.CalciteTrace;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.Sets;
import com.google.common.collect.SortedSetMultimap;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;
import org.slf4j.Logger;

import java.math.BigDecimal;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;

import static org.apache.calcite.linq4j.Nullness.castNonNull;

import static java.util.Objects.requireNonNull;

/**
 * RelDecorrelator replaces all correlated expressions (corExp) in a relational
 * expression (RelNode) tree with non-correlated expressions that are produced
 * from joining the RelNode that produces the corExp with the RelNode that
 * references it.
 *
 * <p>TODO:
 * <ul>
 *   <li>replace {@code CorelMap} constructor parameter with a RelNode</li>
 *   <li>make {@link #currentRel} immutable (would require a fresh
 *      RelDecorrelator for each node being decorrelated)</li>
 *   <li>make fields of {@code CorelMap} immutable</li>
 * </ul>
 *
 * <p>Note: make all the members protected scope so that they can be
 * accessed by the sub-class.
 */
public class RelDecorrelator implements ReflectiveVisitor {
  //~ Static fields/initializers ---------------------------------------------

  // Tracer obtained from CalciteTrace; used in this class for debug-level
  // plan dumps.
  private static final Logger SQL2REL_LOGGER =
      CalciteTrace.getSqlToRelTracer();

  //~ Instance fields --------------------------------------------------------

  // Builder used to construct the rewritten relational expressions; also the
  // prototype for the RelBuilderFactory handed to rules.
  protected final RelBuilder relBuilder;

  // map built during translation
  // (not final: it is rebuilt whenever the plan is restructured)
  protected CorelMap cm;

  /** Stack maintaining visible Frames to the currently invoked RelNode during top-down traversal.
   *  Each entry maps a CorrelationId to the Frame where its correlated variables originate. */
  protected final Deque<Pair<CorrelationId, Frame>> frameStack = new ArrayDeque<>();

  // Reflective dispatcher that routes a RelNode to the "decorrelateRel"
  // overload of the visitor returned by getVisitor(), with the argument types
  // (RelNode, boolean, boolean) and a Frame result.
  @SuppressWarnings("method.invocation.invalid")
  protected final ReflectUtil.MethodDispatcher<@Nullable Frame> dispatcher =
      ReflectUtil.<RelNode, @Nullable Frame>createMethodDispatcher(
          Frame.class, getVisitor(), "decorrelateRel",
          RelNode.class,
          boolean.class,
          boolean.class);

  // The rel which is being visited
  protected @Nullable RelNode currentRel;

  // Planner context; passed to every HepPlanner this class creates.
  protected final Context context;

  /** Built during decorrelation, of rel to all the newly created correlated
   * variables in its output, and to map old input positions to new input
   * positions. This is from the view point of the parent rel of a new rel. */
  protected final Map<RelNode, Frame> map = new HashMap<>();

  // Set of Correlate nodes; the planner copy hook adds the copy of any member
  // that gets copied. NOTE(review): presumably holds Correlates created by
  // this class rather than by the original plan -- confirm against the code
  // that populates it.
  protected final HashSet<Correlate> generatedCorRels = new HashSet<>();

  //~ Constructors -----------------------------------------------------------

  /**
   * Creates a decorrelator.
   *
   * @param cm         Correlation map for the plan being rewritten
   * @param context    Planner context, passed on to the planners created by
   *                   this class
   * @param relBuilder Builder used to construct rewritten relational
   *                   expressions
   */
  protected RelDecorrelator(CorelMap cm, Context context, RelBuilder relBuilder) {
    this.relBuilder = relBuilder;
    this.context = context;
    this.cm = cm;
  }

  //~ Methods ----------------------------------------------------------------

  /**
   * Decorrelates a query using a builder created from
   * {@link RelFactories#LOGICAL_BUILDER} on the cluster of {@code rootRel}.
   *
   * @param rootRel Root node of the query
   * @return result of {@link #decorrelateQuery(RelNode, RelBuilder)}
   */
  @Deprecated // to be removed before 2.0
  public static RelNode decorrelateQuery(RelNode rootRel) {
    return decorrelateQuery(rootRel,
        RelFactories.LOGICAL_BUILDER.create(rootRel.getCluster(), null));
  }

  /** Decorrelates a query.
   *
   * <p>This is the main entry point to {@code RelDecorrelator}. It delegates
   * to {@link #decorrelateQuery(RelNode, RelBuilder, RuleSet)} with a
   * {@code null} rule set, which selects the default rules for the
   * rule-based correlation removal.
   *
   * @param rootRel           Root node of the query
   * @param relBuilder        Builder for relational expressions
   *
   * @return Equivalent query with all
   * {@link org.apache.calcite.rel.core.Correlate} instances removed
   */
  public static RelNode decorrelateQuery(RelNode rootRel,
      RelBuilder relBuilder) {
    return decorrelateQuery(rootRel, relBuilder, null);
  }

  /**
   * Decorrelates a query, optionally overriding the rules used for the
   * rule-based correlation removal.
   *
   * <p>Delegates to
   * {@link #decorrelateQuery(RelNode, RelBuilder, RuleSet, RuleSet)} with
   * {@code null} pre-decorrelate rules, i.e. the default pre-processing
   * program is used.
   *
   * @param rootRel            Root node of the query
   * @param relBuilder         Builder for relational expressions
   * @param decorrelationRules Rules for the rule-based correlation removal,
   *                           or {@code null} to use the default set
   *
   * @return result of the four-argument overload
   */
  public static RelNode decorrelateQuery(RelNode rootRel,
      RelBuilder relBuilder, @Nullable RuleSet decorrelationRules) {
    return decorrelateQuery(rootRel, relBuilder, decorrelationRules, null);
  }

  /**
   * Decorrelates a query, letting the caller choose both the rules of the
   * "remove correlation via rules" step and the rules applied right before
   * the main decorrelation procedure.
   *
   * <p>If the correlation map built for {@code rootRel} reports no
   * correlation, {@code rootRel} is returned as is. Otherwise the rule-based
   * removal runs first; the main procedure runs only if the decorrelator's
   * correlation map still contains entries afterwards. The row type of the
   * result is checked against that of {@code rootRel} (ignoring field names)
   * and hints are re-propagated before returning.
   *
   * @param rootRel             Root node of the query
   * @param relBuilder          Builder for relational expressions
   * @param decorrelationRules  Rules for the initial rule-based conversions;
   *                            if <code>null</code> a default rule set is used
   * @param preDecorrelateRules Rules applied before the main decorrelation
   *                            procedure; if <code>null</code> a default rule
   *                            set is used
   *
   * @return Equivalent query with all
   * {@link org.apache.calcite.rel.core.Correlate} instances removed
   *
   * @see #removeCorrelationViaRule(RelNode, RuleSet)
   */
  public static RelNode decorrelateQuery(RelNode rootRel,
      RelBuilder relBuilder, @Nullable RuleSet decorrelationRules,
      @Nullable RuleSet preDecorrelateRules) {
    final CorelMap initialMap = new CorelMapBuilder().build(rootRel);
    if (!initialMap.hasCorrelation()) {
      // Nothing to do: hand the plan back untouched.
      return rootRel;
    }

    final Context plannerContext =
        rootRel.getCluster().getPlanner().getContext();
    final RelDecorrelator decorrelator =
        new RelDecorrelator(initialMap, plannerContext, relBuilder);

    // Step 1: rule-based removal, with caller-supplied or default rules.
    RelNode result;
    if (decorrelationRules != null) {
      result = decorrelator.removeCorrelationViaRule(rootRel, decorrelationRules);
    } else {
      result = decorrelator.removeCorrelationViaRule(rootRel);
    }

    if (SQL2REL_LOGGER.isDebugEnabled()) {
      final String planDump =
          RelOptUtil.dumpPlan("Plan after removing Correlator", result,
              SqlExplainFormat.TEXT, SqlExplainLevel.EXPPLAN_ATTRIBUTES);
      SQL2REL_LOGGER.debug(planDump);
    }

    // Step 2: main procedure, only while the correlation map has entries.
    final boolean hasRemainingEntries =
        !decorrelator.cm.mapCorToCorRel.isEmpty();
    if (hasRemainingEntries) {
      result = decorrelator.decorrelate(result, preDecorrelateRules);
    }

    // The rewrite must preserve the row type, modulo field names.
    Litmus.THROW.check(
        rootRel.getRowType().equalsSansFieldNames(result.getRowType()),
        "Decorrelation produced a relation with a different type; before: "
            + rootRel.getRowType() + " after: " + result.getRowType());

    // Re-propagate the hints.
    return RelOptUtil.propagateRelHints(result, true);
  }

  /**
   * Records {@code corRel} as the rel currently being visited and, when it is
   * non-null, rebuilds the correlation map.
   *
   * @param root   Plan root to rebuild the map from; when {@code null},
   *               {@code corRel} itself is used
   * @param corRel Correlate that becomes {@link #currentRel}; may be
   *               {@code null}, in which case the map is left untouched
   */
  private void setCurrent(@Nullable RelNode root, @Nullable Correlate corRel) {
    currentRel = corRel;
    if (corRel == null) {
      return;
    }
    final CorelMapBuilder mapBuilder = new CorelMapBuilder();
    cm = mapBuilder.build(Util.first(root, corRel));
  }

  /**
   * Returns a {@link RelBuilderFactory} derived from this decorrelator's
   * {@link #relBuilder} via {@link RelBuilder#proto(Object...)}.
   *
   * <p>Within this class the factory is applied to every rule configuration
   * that accepts one.
   */
  protected RelBuilderFactory relBuilderFactory() {
    return RelBuilder.proto(relBuilder);
  }

  /**
   * Runs the main decorrelation procedure with the default pre-decorrelate
   * rules.
   *
   * @param root Root of the plan to decorrelate
   * @return result of {@link #decorrelate(RelNode, RuleSet)} called with a
   * {@code null} rule set
   */
  protected RelNode decorrelate(RelNode root) {
    return decorrelate(root, null);
  }

  /**
   * Runs the main decorrelation procedure on {@code root}.
   *
   * <p>Steps, in order:
   * <ol>
   *   <li>apply the pre-decorrelate rules ({@code preDecorrelateRules}, or a
   *   built-in default program when {@code null});</li>
   *   <li>run {@code CorrelateProjectExtractor} over the plan and rebuild
   *   {@link #cm} from the result;</li>
   *   <li>clear {@link #map} and call {@code getInvoke} on the root;</li>
   *   <li>if a frame was produced, project away output fields that have no
   *   counterpart in the original plan and apply the post-decorrelation
   *   rules.</li>
   * </ol>
   *
   * @param root                Root of the plan to decorrelate
   * @param preDecorrelateRules Rules applied before decorrelation, or
   *                            {@code null} to use the default set
   * @return the rewritten plan, or the pre-processed plan when no frame was
   * produced for the root
   */
  protected RelNode decorrelate(RelNode root, @Nullable RuleSet preDecorrelateRules) {
    final RelBuilderFactory f = relBuilderFactory();
    final HepProgram program;
    if (preDecorrelateRules != null) {
      program = ruleSetToHepProgram(preDecorrelateRules);
    } else {
      // Use a default set of pre-decorrelate rules:
      // adjust count() expression if any, and do some filter-related transformations
      program = HepProgram.builder()
          .addRuleInstance(
              AdjustProjectForCountAggregateRule.DEFAULT_WITHOUT_FAVLOR
                  .withRelBuilderFactory(f).toRule())
          .addRuleInstance(
              AdjustProjectForCountAggregateRule.DEFAULT_WITH_FAVLOR
                  .withRelBuilderFactory(f).toRule())
          .addRuleInstance(
              // Filter-into-Join rule whose predicate accepts every condition.
              FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.DEFAULT
                  .withRelBuilderFactory(f)
                  .withOperandSupplier(b0 ->
                      b0.operand(Filter.class).oneInput(b1 ->
                          b1.operand(Join.class).anyInputs()))
                  .withDescription("FilterJoinRule:filter")
                  .as(FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.class)
                  .withSmart(true)
                  .withPredicate((join, joinType, exp) -> true)
                  .as(FilterJoinRule.FilterIntoJoinRule.FilterIntoJoinRuleConfig.class)
                  .toRule())
          .addRuleInstance(
              // Only filters whose condition contains no correlation are
              // transposed with a Project.
              CoreRules.FILTER_PROJECT_TRANSPOSE.config
                  .withRelBuilderFactory(f)
                  .as(FilterProjectTransposeRule.Config.class)
                  .withOperandFor(Filter.class, filter ->
                          !RexUtil.containsCorrelation(filter.getCondition()),
                      Project.class, project -> true)
                  .withCopyFilter(true)
                  .withCopyProject(true)
                  .toRule())
          .addRuleInstance(FilterCorrelateRule.Config.DEFAULT
              .withRelBuilderFactory(f)
              .toRule())
          .addRuleInstance(FilterFlattenCorrelatedConditionRule.Config.DEFAULT
              .withRelBuilderFactory(f)
              .toRule())
          .build();
    }

    root = applyHepProgram(root, program);
    if (SQL2REL_LOGGER.isDebugEnabled()) {
      SQL2REL_LOGGER.debug("Plan before extracting correlated computations:\n"
          + RelOptUtil.toString(root));
    }
    root = root.accept(new CorrelateProjectExtractor(f));
    // Necessary to update cm (CorrelMap) since CorrelateProjectExtractor above may modify the plan
    this.cm = new CorelMapBuilder().build(root);
    if (SQL2REL_LOGGER.isDebugEnabled()) {
      SQL2REL_LOGGER.debug("Plan after extracting correlated computations:\n"
          + RelOptUtil.toString(root));
    }
    // Perform decorrelation.
    // Frames registered by any earlier run are discarded first.
    map.clear();

    final Frame frame = getInvoke(root, false, null, true);
    if (frame != null) {
      // Check if the frame has more fields than the original and discard the extra ones
      RelNode result = frame.r;
      int fields = frame.r.getRowType().getFieldCount();
      if (fields > frame.oldToNewOutputs.size()) {
        relBuilder.push(result);
        final List<RexNode> exprList = new ArrayList<>();
        // Project the new positions in ascending order of the old position,
        // so the output columns follow the original column order.
        List<Map.Entry<Integer, Integer>> entries =
            new ArrayList<>(frame.oldToNewOutputs.entrySet());
        entries.sort(Map.Entry.comparingByKey());
        for (Map.Entry<Integer, Integer> entry : entries) {
          exprList.add(relBuilder.field(entry.getValue()));
        }
        relBuilder.project(exprList);
        result = relBuilder.build();
      } else {
        Litmus.THROW.check(fields == frame.oldToNewOutputs.size(),
            "Produced relation has fewer columns than the original relation");
      }

      // has been rewritten; apply rules post-decorrelation
      final HepProgramBuilder builder = HepProgram.builder()
          .addRuleInstance(
              CoreRules.FILTER_INTO_JOIN.config
                  .withRelBuilderFactory(f)
                  .toRule())
          .addRuleInstance(
              CoreRules.JOIN_CONDITION_PUSH.config
                  .withRelBuilderFactory(f)
                  .toRule());
      // Extra post-decorrelation rules, if getPostDecorrelateRules() has any.
      if (!getPostDecorrelateRules().isEmpty()) {
        builder.addRuleCollection(getPostDecorrelateRules());
      }
      final HepProgram program2 = builder.build();
      return applyHepProgram(result, program2);
    }

    // No frame for the root: return the plan as left by the pre-processing
    // steps above.
    return root;
  }

  /**
   * Creates the callback that keeps this decorrelator's bookkeeping in step
   * when a planner copies a node.
   *
   * <p>For a copy {@code oldNode -> newNode} the callback:
   * <ul>
   *   <li>duplicates the entries of {@code cm.mapRefRelToCorRef} keyed by
   *   {@code oldNode} under {@code newNode};</li>
   *   <li>if both nodes are {@link Correlate}s, re-points the
   *   {@code cm.mapCorToCorRel} entry of the old correlation id to
   *   {@code newNode} when it currently refers to {@code oldNode}, and adds
   *   {@code newNode} to {@link #generatedCorRels} when {@code oldNode} is a
   *   member.</li>
   * </ul>
   *
   * <p>{@link #cm} is read each time the callback runs, so it always works on
   * the current correlation map.
   */
  private Function2<RelNode, RelNode, @Nullable Void> createCopyHook() {
    return (oldNode, newNode) -> {
      if (cm.mapRefRelToCorRef.containsKey(oldNode)) {
        cm.mapRefRelToCorRef.putAll(newNode,
            cm.mapRefRelToCorRef.get(oldNode));
      }

      // The remaining bookkeeping only concerns Correlate-to-Correlate copies.
      if (!(oldNode instanceof Correlate) || !(newNode instanceof Correlate)) {
        return null;
      }

      final CorrelationId id = ((Correlate) oldNode).getCorrelationId();
      if (cm.mapCorToCorRel.get(id) == oldNode) {
        cm.mapCorToCorRel.put(id, newNode);
      }

      if (generatedCorRels.contains(oldNode)) {
        final Correlate copiedCorrelate = (Correlate) newNode;
        generatedCorRels.add(copiedCorrelate);
      }
      return null;
    };
  }

  /**
   * Creates a {@link HepPlanner} for {@code program} that is wired to this
   * decorrelator.
   *
   * <p>The planner receives the copy hook from {@link #createCopyHook()}, so
   * the mapping tables are updated whenever a node is copied on registration,
   * and this instance is set as its decorrelator.
   *
   * @param program Program the planner will execute
   * @return the configured planner; no root has been set yet
   */
  private HepPlanner createPlanner(HepProgram program) {
    final HepPlanner hepPlanner =
        new HepPlanner(program, context, true, createCopyHook(),
            RelOptCostImpl.FACTORY);
    hepPlanner.setDecorrelator(this);
    return hepPlanner;
  }

  /**
   * Removes some instances of {@link org.apache.calcite.rel.core.Correlate}
   * from a query plan by applying a default set of rules. Only some
   * {@link org.apache.calcite.rel.core.Correlate}s can be removed this way.
   *
   * <p>The default program consists of {@code RemoveSingleAggregateRule},
   * {@code RemoveCorrelationForScalarProjectRule} and
   * {@code RemoveCorrelationForScalarAggregateRule}, in that order.
   *
   * @param root Root of the plan
   * @return plan produced by the planner
   */
  public RelNode removeCorrelationViaRule(RelNode root) {
    final RelBuilderFactory factory = relBuilderFactory();
    final HepProgramBuilder programBuilder = HepProgram.builder();
    programBuilder.addRuleInstance(RemoveSingleAggregateRule.DEFAULT.toRule());
    programBuilder.addRuleInstance(
        RemoveCorrelationForScalarProjectRule.DEFAULT
            .withRelBuilderFactory(factory)
            .toRule());
    programBuilder.addRuleInstance(
        RemoveCorrelationForScalarAggregateRule.DEFAULT
            .withRelBuilderFactory(factory)
            .toRule());
    return applyHepProgram(root, programBuilder.build());
  }

  /**
   * Removes some instances of {@link org.apache.calcite.rel.core.Correlate}
   * from a query plan by applying the given {@link RuleSet}. Only some
   * {@link org.apache.calcite.rel.core.Correlate}s can be removed this way.
   *
   * @param root    Root of the plan
   * @param ruleSet Rules to apply; they are converted to a program by
   *                {@code ruleSetToHepProgram}
   * @return plan produced by the planner
   */
  public RelNode removeCorrelationViaRule(RelNode root, RuleSet ruleSet) {
    final HepProgram program = ruleSetToHepProgram(ruleSet);
    return applyHepProgram(root, program);
  }

  /**
   * Builds a {@link HepProgram} containing the rules of {@code ruleSet} in
   * iteration order.
   *
   * <p>A rule that is a {@link RelRule} is re-created from its config with
   * this decorrelator's {@link RelBuilderFactory}; any other rule is added
   * unchanged.
   *
   * @param ruleSet Rules to include
   * @return program with one rule instance per rule in the set
   */
  private HepProgram ruleSetToHepProgram(RuleSet ruleSet) {
    final RelBuilderFactory factory = relBuilderFactory();
    final HepProgramBuilder programBuilder = HepProgram.builder();
    for (RelOptRule candidate : ruleSet) {
      final RelOptRule effectiveRule = candidate instanceof RelRule
          ? ((RelRule<?>) candidate).config.withRelBuilderFactory(factory).toRule()
          : candidate;
      programBuilder.addRuleInstance(effectiveRule);
    }
    return programBuilder.build();
  }

  /**
   * Runs {@code program} over {@code root} with a planner created by
   * {@link #createPlanner(HepProgram)}.
   *
   * @param root    Root of the plan to transform
   * @param program Program to execute
   * @return the planner's best expression
   */
  private RelNode applyHepProgram(RelNode root, HepProgram program) {
    final HepPlanner hepPlanner = createPlanner(program);
    hepPlanner.setRoot(root);
    return hepPlanner.findBestExp();
  }

  /**
   * Rewrites {@code exp} with a {@code DecorrelateRexShuttle} constructed
   * from the given rel, frame map and correlation map.
   *
   * @param currentRel Rel that owns the expression
   * @param map        Map from rel to frame
   * @param cm         Correlation map
   * @param exp        Expression to rewrite
   * @return expression returned by the shuttle
   */
  protected RexNode decorrelateExpr(RelNode currentRel,
      Map<RelNode, Frame> map, CorelMap cm, RexNode exp) {
    return exp.accept(new DecorrelateRexShuttle(currentRel, map, cm));
  }

  /**
   * Rewrites {@code exp} with a {@code RemoveCorrelationRexShuttle} that has
   * no null indicator and an empty "is count" set.
   *
   * @param exp                              Expression to rewrite
   * @param projectPulledAboveLeftCorrelator Flag passed through to the shuttle
   * @return expression returned by the shuttle
   */
  protected RexNode removeCorrelationExpr(
      RexNode exp,
      boolean projectPulledAboveLeftCorrelator) {
    return exp.accept(
        new RemoveCorrelationRexShuttle(relBuilder.getRexBuilder(),
            projectPulledAboveLeftCorrelator, null, ImmutableSet.of()));
  }

  /**
   * Rewrites {@code exp} with a {@code RemoveCorrelationRexShuttle} that uses
   * the given null indicator and an empty "is count" set.
   *
   * @param exp                              Expression to rewrite
   * @param projectPulledAboveLeftCorrelator Flag passed through to the shuttle
   * @param nullIndicator                    Null indicator passed through to
   *                                         the shuttle
   * @return expression returned by the shuttle
   */
  protected RexNode removeCorrelationExpr(
      RexNode exp,
      boolean projectPulledAboveLeftCorrelator,
      RexInputRef nullIndicator) {
    return exp.accept(
        new RemoveCorrelationRexShuttle(relBuilder.getRexBuilder(),
            projectPulledAboveLeftCorrelator, nullIndicator,
            ImmutableSet.of()));
  }

  /**
   * Rewrites {@code exp} with a {@code RemoveCorrelationRexShuttle} that has
   * no null indicator and the given "is count" set, then makes the result
   * nullable when {@code projectPulledAboveLeftCorrelator} is set.
   *
   * @param exp                              Expression to rewrite
   * @param projectPulledAboveLeftCorrelator Flag passed through to the
   *                                         shuttle; also decides whether the
   *                                         result is made nullable
   * @param isCount                          Set passed through to the shuttle
   * @return rewritten expression
   */
  protected RexNode removeCorrelationExpr(
      RexNode exp,
      boolean projectPulledAboveLeftCorrelator,
      Set<Integer> isCount) {
    final RexBuilder rexBuilder = relBuilder.getRexBuilder();
    final RexNode rewritten =
        exp.accept(
            new RemoveCorrelationRexShuttle(rexBuilder,
                projectPulledAboveLeftCorrelator, null, isCount));

    // Fix the nullability.
    return projectPulledAboveLeftCorrelator
        ? rexBuilder.makeNullable(rewritten)
        : rewritten;
  }

  /**
   * Fallback if none of the other {@code decorrelateRel} methods match.
   *
   * <p>Each input is decorrelated through {@code getInvoke}. The rewrite is
   * abandoned as soon as one input cannot be rewritten or still produces
   * correlated variables. Otherwise {@code rel} is copied onto the new inputs
   * and registered with an identity output mapping and no correlated-variable
   * outputs.
   *
   * @param rel                        Rel to rewrite
   * @param isCorVarDefined            Passed through to {@code getInvoke} for
   *                                   every input
   * @param parentPropagatesNullValues Passed through to {@code getInvoke} for
   *                                   every input
   * @return frame for the rewritten rel, or {@code null} if the rewrite was
   * abandoned
   */
  public @Nullable Frame decorrelateRel(RelNode rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    RelNode newRel = rel.copy(rel.getTraitSet(), rel.getInputs());

    if (!rel.getInputs().isEmpty()) {
      List<RelNode> oldInputs = rel.getInputs();
      List<RelNode> newInputs = new ArrayList<>();
      for (int i = 0; i < oldInputs.size(); ++i) {
        final Frame frame =
            getInvoke(oldInputs.get(i), isCorVarDefined, rel, parentPropagatesNullValues);
        if (frame == null || !frame.corDefOutputs.isEmpty()) {
          // if input is not rewritten, or if it produces correlated
          // variables, terminate rewrite
          return null;
        }
        newInputs.add(frame.r);
        // NOTE(review): this in-place update looks superseded by the copy
        // below whenever any input actually changed -- confirm before
        // simplifying.
        newRel.replaceInput(i, frame.r);
      }

      // Re-copy from the original rel when at least one input differs.
      if (!Util.equalShallow(oldInputs, newInputs)) {
        newRel = rel.copy(rel.getTraitSet(), newInputs);
      }
    }

    // the output position should not change since there are no corVars
    // coming from below.
    return register(rel, newRel, identityMap(rel.getRowType().getFieldCount()),
        ImmutableSortedMap.of());
  }

  /**
   * Decorrelates a {@link Sort}.
   *
   * <p>Rewrite logic: the collation is re-expressed against the output
   * positions of the rewritten input. A Sort does not reorder columns, so the
   * input's output mapping and correlated-variable outputs are registered
   * unchanged for the new Sort.
   *
   * <p>When correlated variables are defined and the Sort has an offset or
   * fetch, the rewrite is only attempted for "fetch 1 without offset" (via
   * {@code decorrelateFetchOneSort}). Any other limit would change scope from
   * per-correlation-key to global after decorrelation, so {@code null} is
   * returned.
   *
   * @param rel                        Sort to rewrite
   * @param isCorVarDefined            Whether correlated variables are
   *                                   defined; passed to {@code getInvoke}
   * @param parentPropagatesNullValues Not read by this method
   * @return frame for the rewritten Sort, or {@code null} if it cannot be
   * rewritten
   */
  public @Nullable Frame decorrelateRel(Sort rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    // Sort itself should not reference corVars.
    assert !cm.mapRefRelToCorRef.containsKey(rel);

    final RelNode oldInput = rel.getInput();
    // NOTE(review): a literal true is passed here, whereas the Aggregate and
    // fallback variants forward parentPropagatesNullValues -- confirm this is
    // intended.
    final Frame inputFrame = getInvoke(oldInput, isCorVarDefined, rel, true);
    if (inputFrame == null) {
      // If input has not been rewritten, do not rewrite this rel.
      return null;
    }

    final boolean hasLimit = rel.fetch != null || rel.offset != null;
    if (isCorVarDefined && hasLimit) {
      final boolean fetchOneOnly = rel.offset == null
          && rel.fetch != null
          && RexLiteral.intValue(rel.fetch) == 1;
      if (fetchOneOnly) {
        return decorrelateFetchOneSort(rel, inputFrame);
      }
      // Offset/fetch are per-correlation-key attributes; after decorrelation
      // they would apply globally, so this Sort cannot be decorrelated.
      return null;
    }

    // Translate the collation to the new output positions of the input.
    final RelNode newInput = inputFrame.r;
    final int oldFieldCount = oldInput.getRowType().getFieldCount();
    final int newFieldCount = newInput.getRowType().getFieldCount();
    final Mappings.TargetMapping positionMapping =
        Mappings.target(inputFrame.oldToNewOutputs, oldFieldCount, newFieldCount);
    final RelCollation newCollation =
        RexUtil.apply(positionMapping, rel.getCollation());

    // Push first: the field references below resolve against the pushed input.
    relBuilder.push(newInput);
    final RelNode newSort = relBuilder
        .sortLimit(rel.offset, rel.fetch, relBuilder.fields(newCollation))
        .build();

    // Sort does not change input ordering
    return register(rel, newSort, inputFrame.oldToNewOutputs,
        inputFrame.corDefOutputs);
  }

  /**
   * Decorrelates a {@link LogicalAggregate} by delegating to the
   * {@link Aggregate} overload with the same arguments.
   *
   * @param rel                        Aggregate to rewrite
   * @param isCorVarDefined            Passed through unchanged
   * @param parentPropagatesNullValues Passed through unchanged
   * @return result of the {@link Aggregate} overload
   */
  public @Nullable Frame decorrelateRel(LogicalAggregate rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    return decorrelateRel((Aggregate) rel, isCorVarDefined, parentPropagatesNullValues);
  }

  /**
   * Decorrelates an {@link Aggregate}.
   *
   * <p>The rewritten input is wrapped in a Project that lists, in order: the
   * group keys for which {@code projectedLiteral} returns no literal, the
   * correlated-variable columns produced by the input, and all remaining
   * input fields. The Aggregate is rebuilt on that Project, grouping by the
   * first two sections. Group keys omitted as literals are re-added by a
   * Project on top of the new Aggregate.
   *
   * <p>If the Aggregate is {@code SIMPLE} with an empty group set, its input
   * produces correlated variables, and either
   * {@code parentPropagatesNullValues} is false or the Aggregate contains a
   * COUNT call, the result is additionally passed through
   * {@code rewriteScalarAggregate}.
   *
   * @param rel                        Aggregate to rewrite
   * @param isCorVarDefined            Passed through to {@code getInvoke} for
   *                                   the input
   * @param parentPropagatesNullValues Passed through to {@code getInvoke} for
   *                                   the input; see above for its effect on
   *                                   scalar aggregates
   * @return frame for the rewritten Aggregate, or {@code null} if the input
   * has not been rewritten
   */
  public @Nullable Frame decorrelateRel(Aggregate rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    //
    // Rewrite logic:
    //
    // 1. Permute the group by keys to the front.
    // 2. If the input of an aggregate produces correlated variables,
    //    add them to the group list.
    // 3. Change aggCalls to reference the new project.
    //

    // Aggregate itself should not reference corVars.
    assert !cm.mapRefRelToCorRef.containsKey(rel);

    final RelNode oldInput = rel.getInput();
    final Frame frame = getInvoke(oldInput, isCorVarDefined, rel, parentPropagatesNullValues);
    if (frame == null) {
      // If input has not been rewritten, do not rewrite this rel.
      return null;
    }
    final RelNode newInput = frame.r;

    // aggregate outputs mapping: group keys and aggregates
    final Map<Integer, Integer> outputMap = new HashMap<>();

    // map from newInput
    final Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
    final int oldGroupKeyCount = rel.getGroupSet().cardinality();

    // Project projects the original expressions,
    // plus any correlated variables the input wants to pass along.
    final PairList<RexNode, String> projects = PairList.of();

    List<RelDataTypeField> newInputOutput =
        newInput.getRowType().getFieldList();

    // Next free output position of the Project being assembled.
    int newPos = 0;

    final List<Integer> groupKeyIndices = rel.getGroupSet().asList();
    // Literal group keys left out of the new group set, keyed by their
    // position in the old input.
    final NavigableMap<Integer, RexLiteral> omittedConstants = new TreeMap<>();
    for (int i = 0; i < oldGroupKeyCount; i++) {
      final int idx = groupKeyIndices.get(i);
      final RexLiteral constant = projectedLiteral(newInput, idx);
      if (constant != null) {
        // Exclude constants. Aggregate({true}) occurs because Aggregate({})
        // would generate 1 row even when applied to an empty table.
        omittedConstants.put(idx, constant);
        continue;
      }

      // add mapping of group keys.
      outputMap.put(i, newPos);
      int newInputPos = requireNonNull(frame.oldToNewOutputs.get(idx));
      RexInputRef.add2(projects, newInputPos, newInputOutput);
      mapNewInputToProjOutputs.put(newInputPos, newPos);
      newPos++;
    }

    final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
    if (!frame.corDefOutputs.isEmpty()) {
      // If input produces correlated variables, move them to the front,
      // right after any existing GROUP BY fields.

      // Now add the corVars from the input, starting from
      // position oldGroupKeyCount.
      for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
        // Verify if the CorDef position was already added to the mapNewInputToProjOutputs
        // during the previous group key processing
        final Integer pos = mapNewInputToProjOutputs.get(entry.getValue());
        if (pos == null) {
          RexInputRef.add2(projects, entry.getValue(), newInputOutput);
          corDefOutputs.put(entry.getKey(), newPos);
          mapNewInputToProjOutputs.put(entry.getValue(), newPos);
          newPos++;
        } else {
          // Already projected as a group key: reuse that position.
          corDefOutputs.put(entry.getKey(), pos);
        }
      }
    }

    // add the remaining fields
    // (everything projected so far becomes a group key of the new Aggregate)
    final int newGroupKeyCount = newPos;
    for (int i = 0; i < newInputOutput.size(); i++) {
      if (!mapNewInputToProjOutputs.containsKey(i)) {
        RexInputRef.add2(projects, i, newInputOutput);
        mapNewInputToProjOutputs.put(i, newPos);
        newPos++;
      }
    }

    // This Project will be what the old input maps to,
    // replacing any previous mapping from old input.
    RelNode newProject = relBuilder.push(newInput)
        .projectNamed(projects.leftList(), projects.rightList(), true)
        .build();

    // update mappings:
    // oldInput ----> newInput
    //
    //                newProject
    //                   |
    // oldInput ----> newInput
    //
    // is transformed to
    //
    // oldInput ----> newProject
    //                   |
    //                newInput
    Map<Integer, Integer> combinedMap = new HashMap<>();

    for (Map.Entry<Integer, Integer> entry : frame.oldToNewOutputs.entrySet()) {
      combinedMap.put(entry.getKey(),
          requireNonNull(mapNewInputToProjOutputs.get(entry.getValue()),
              () -> "mapNewInputToProjOutputs.get(" + entry.getValue() + ")"));
    }

    register(oldInput, newProject, combinedMap, corDefOutputs);

    // now it's time to rewrite the Aggregate
    final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount);
    List<AggregateCall> newAggCalls = new ArrayList<>();
    List<AggregateCall> oldAggCalls = rel.getAggCallList();

    final Iterable<ImmutableBitSet> newGroupSets;
    if (rel.getGroupType() == Aggregate.Group.SIMPLE) {
      newGroupSets = null;
    } else {
      // Every original grouping set is extended with the positions in the
      // range [oldGroupKeyCount, newGroupKeyCount).
      final ImmutableBitSet addedGroupSet =
          ImmutableBitSet.range(oldGroupKeyCount, newGroupKeyCount);
      newGroupSets =
          ImmutableBitSet.ORDERING.immutableSortedCopy(
              Util.transform(rel.getGroupSets(),
                  bitSet -> bitSet.union(addedGroupSet)));
    }

    // Despite the names, these are the group key counts of the old and new
    // Aggregate, i.e. the offsets at which their aggregate outputs start.
    int oldInputOutputFieldCount = rel.getGroupSet().cardinality();
    int newInputOutputFieldCount = newGroupSet.cardinality();

    int i = -1;
    for (AggregateCall oldAggCall : oldAggCalls) {
      ++i;
      List<Integer> oldAggArgs = oldAggCall.getArgList();

      List<Integer> aggArgs = new ArrayList<>();

      // Adjust the Aggregate argument positions.
      // Note Aggregate does not change input ordering, so the input
      // output position mapping can be used to derive the new positions
      // for the argument.
      for (int oldPos : oldAggArgs) {
        aggArgs.add(
            requireNonNull(combinedMap.get(oldPos),
                () -> "combinedMap.get(" + oldPos + ")"));
      }
      // A negative filterArg is kept as is; otherwise it is remapped like
      // the arguments.
      final int filterArg =
          oldAggCall.filterArg < 0 ? oldAggCall.filterArg
              : requireNonNull(combinedMap.get(oldAggCall.filterArg),
                  () -> "combinedMap.get(" + oldAggCall.filterArg + ")");

      // True if the new Aggregate has an empty grouping set. The value does
      // not depend on oldAggCall; it is recomputed on every iteration.
      boolean newHasEmptyGroup = newGroupSets == null && newGroupSet.isEmpty();
      if (newGroupSets != null) {
        Iterator<ImmutableBitSet> groupSetsIterator = newGroupSets.iterator();
        while (!newHasEmptyGroup && groupSetsIterator.hasNext()) {
          newHasEmptyGroup |= groupSetsIterator.next().isEmpty();
        }
      }
      newAggCalls.add(
          oldAggCall.adaptTo(newProject, aggArgs, filterArg,
              rel.hasEmptyGroup(), newHasEmptyGroup));

      // The old to new output position mapping will be the same as that
      // of newProject, plus any aggregates that the oldAgg produces.
      outputMap.put(
          oldInputOutputFieldCount + i,
          newInputOutputFieldCount + i);
    }

    relBuilder.push(newProject)
        .aggregate(newGroupSets == null
                ? relBuilder.groupKey(newGroupSet)
                : relBuilder.groupKey(newGroupSet, newGroupSets),
            newAggCalls);

    if (!omittedConstants.isEmpty()) {
      // Re-insert the omitted literal keys with a Project on top of the new
      // Aggregate, iterating from the highest key to the lowest.
      final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
      for (Map.Entry<Integer, RexLiteral> entry
          : omittedConstants.descendingMap().entrySet()) {
        int index = entry.getKey() + frame.corDefOutputs.size();
        postProjects.add(index, entry.getValue());
        // Shift the outputs whose index equals with or bigger than the added index
        // with 1 offset.
        shiftMapping(outputMap, index, 1);
        // Then add the constant key mapping.
        outputMap.put(entry.getKey(), index);
      }
      relBuilder.project(postProjects);
    }

    RelNode newRel = relBuilder.build();

    // An Aggregate containing a COUNT call is always treated as if the
    // parent did not propagate null values.
    for (AggregateCall aggCall : rel.getAggCallList()) {
      if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
        parentPropagatesNullValues = false;
        break;
      }
    }

    // Scalar aggregate (SIMPLE, no group keys) over a correlated input:
    // apply the dedicated rewrite.
    if (rel.getGroupType() == Aggregate.Group.SIMPLE
        && rel.getGroupSet().isEmpty()
        && !frame.corDefOutputs.isEmpty()
        && !parentPropagatesNullValues) {
      newRel = rewriteScalarAggregate(rel, newRel, outputMap, corDefOutputs);
    }

    // Aggregate does not change input ordering so corVars will be
    // located at the same position as the input newProject.
    return register(rel, newRel, outputMap, corDefOutputs);
  }

  /**
   * Special case where the group by is static (i.e., aggregation functions without group by).
   *
   * <p>Background:
   *   For the query:
   *     SELECT SUM(salary), COUNT(name) FROM A;
   *   When table A is empty, it returns [null, 0].
   *   But for
   *     SELECT SUM(salary), COUNT(name) FROM A group by id
   *   When table A is empty, it returns empty. This causes result mismatch.
   * In the general decorrelation framework, we add corVar as an additional groupKey to
   * rewrite Correlate as JOIN. (See the code above for details) This means that when the input
   * is empty, the result produced using a JOIN is incorrect.
   *
   * <p>This situation is the well-known "count bug". For more details, see:
   * Optimization of Nested SQL Queries Revisited
   * (https://dl.acm.org/doi/pdf/10.1145/38714.38723)
   *
   * <p>To handle this situation, we use a LEFT JOIN to ensure that an output is always produced.
   *
   * <p>Given the SQL:
   *   SELECT deptno FROM dept d
   *     WHERE 0 = (SELECT COUNT(*) FROM emp e WHERE d.deptno = e.deptno)
   * Corresponding plan:
   *    LogicalProject(DEPTNO=[$0])
   *      LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}])
   *        LogicalProject(DEPTNO=[$0])
   *          LogicalTableScan(table=[[scott, DEPT]])
   *        LogicalProject(cs=[true])
   *          LogicalFilter(condition=[=(0, $0)])
   *            LogicalAggregate(group=[{}], EXPR$0=[COUNT()])
   *              LogicalFilter(condition=[=($cor0.DEPTNO, $7)])
   *                LogicalTableScan(table=[[scott, EMP]])
   *
   * <p>Rewriting this as:
   *   SELECT d.deptno FROM dept d
   *     JOIN (
   *         SELECT true, e.deptno FROM emp e WHERE e.deptno IS NOT NULL
   *         GROUP BY e.deptno HAVING COUNT(*) = 0
   *     ) AS d0 ON d.deptno = d0.deptno
   * produces an incorrect result.
   * Corresponding plan:
   *    LogicalProject(DEPTNO=[$0])
   *      LogicalJoin(condition=[=($0, $2)], joinType=[inner])
   *        LogicalProject(DEPTNO=[$0])
   *          LogicalTableScan(table=[[scott, DEPT]])
   *        LogicalProject(cs=[true], DEPTNO=[$0])
   *          LogicalFilter(condition=[=(0, $1)])
   *            LogicalAggregate(group=[{0}], EXPR$0=[COUNT()]) // corresponds to {@code oldRel}
   *              LogicalProject(DEPTNO=[$7])
   *                LogicalFilter(condition=[IS NOT NULL($7)])
   *                  LogicalTableScan(table=[[scott, EMP]])
   *  We can clearly observe that due to the presence of the GROUP BY clause,
   *  COUNT(*) = 0 will never evaluate to true, since rows with zero records won't appear
   *  in the GROUP BY results. This produced incorrect results.
   *
   * <p>Rewrite Aggregate as:
   *   SELECT d.deptno FROM dept d
   *     JOIN (
   *          SELECT true AS cs, deptno
   *          FROM (
   *              SELECT d2.deptno,
   *                     CASE WHEN cnt0 IS NOT NULL THEN cnt0 ELSE 0 END AS cnt
   *              FROM (SELECT deptno FROM dept GROUP BY deptno) d2
   *              LEFT JOIN (
   *                  SELECT deptno, COUNT(e.empno) cnt0
   *                  FROM emp
   *                  WHERE deptno IS NOT NULL
   *                  GROUP BY deptno) e
   *              ON d2.deptno IS NOT DISTINCT FROM e.deptno
   *          ) AS case_count
   *          WHERE cnt = 0
   *     ) AS d0 ON d.deptno = d0.deptno
   * Corresponding plan:
   * [01] LogicalProject(DEPTNO=[$0])
   * [02]   LogicalJoin(condition=[=($0, $2)], joinType=[inner])
   * [03]     LogicalProject(DEPTNO=[$0])
   * [04]       LogicalTableScan(table=[[scott, DEPT]])
   * [05]     LogicalProject(cs=[true], DEPTNO=[$0])
   * [06]       LogicalFilter(condition=[=(0, $1)])
   * [07]         LogicalProject(DEPTNO=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)])
   * [08]           LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[left])
   * [09]             LogicalAggregate(group=[{0}])
   * [10]               LogicalProject(DEPTNO=[$0])
   * [11]                 LogicalTableScan(table=[[scott, DEPT]])
   * [12]             LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])
   * [13]               LogicalProject(DEPTNO=[$7])
   * [14]                 LogicalFilter(condition=[IS NOT NULL($7)])
   * [15]                   LogicalTableScan(table=[[scott, EMP]])
   *
   * <p>Here we perform an early join, preserving all possible CorVar sets from the outer scope
   * and their corresponding aggregation results. This ensures that for any row from the left
   * input of the Correlation, there is always an aggregation result available for join output.
   *
   * <p>Implementation based on: Improving Unnesting of Complex Queries
   * (https://dl.gi.de/server/api/core/bitstreams/c1918e8c-6a87-4da2-930a-bfed289f2388/content)
   */
  private RelNode rewriteScalarAggregate(Aggregate oldRel,
      RelNode newRel,
      Map<Integer, Integer> outputMap,
      NavigableMap<CorDef, Integer> corDefOutputs) {
    // The pair on top of the frame stack supplies both the correlation id we
    // are rewriting for and the frame whose relational expression provides
    // the values of its correlation variables.
    final Pair<CorrelationId, Frame> outerFramePair = requireNonNull(this.frameStack.peek());
    final Frame outFrame = outerFramePair.right;
    RexBuilder rexBuilder = relBuilder.getRexBuilder();

    // Number of correlation variables in corDefOutputs that belong to the
    // correlation id on top of the stack. The join built below has the
    // distinct corVar aggregate as its left input, so every field of newRel is
    // addressed at "index + groupKeySize" in the join output.
    int groupKeySize = (int) corDefOutputs.keySet().stream()
        .filter(a -> a.corr.equals(outerFramePair.left))
        .count();
    List<RelDataTypeField> newRelFields = newRel.getRowType().getFieldList();
    ImmutableBitSet.Builder corFieldBuilder = ImmutableBitSet.builder();

    // Here we record the mapping between the original index and the new project.
    // For the count, we map it as `case when x is null then 0 else x`.
    final Map<Integer, RexNode> newProjectMap = new HashMap<>();
    final List<RexNode> conditions = new ArrayList<>();
    for (Map.Entry<CorDef, Integer> corDefOutput : corDefOutputs.entrySet()) {
      CorDef corDef = corDefOutput.getKey();
      Integer corIndex = corDefOutput.getValue();
      // Only variables of the correlation currently being rewritten take part.
      if (corDef.corr.equals(outerFramePair.left)) {
        // Position of the correlated field in the outer frame's new output.
        int newIdx = requireNonNull(outFrame.oldToNewOutputs.get(corDef.field));
        corFieldBuilder.set(newIdx);

        RelDataType type = outFrame.r.getRowType().getFieldList().get(newIdx).getType();
        // NOTE(review): the left-side ordinal is taken from the number of bits
        // set so far, i.e. it follows iteration order, whereas an aggregate
        // emits group keys in ascending bit order. This presumably relies on
        // newIdx values being distinct and increasing across iterations -
        // TODO confirm.
        RexNode left = new RexInputRef(corFieldBuilder.cardinality() - 1, type);
        // The corVar column coming from newRel is replaced in the final
        // projection by the left-side value, which is present for every row.
        newProjectMap.put(corIndex + groupKeySize, left);
        // Null-safe equality between the left corVar and newRel's corVar column.
        conditions.add(
            relBuilder.isNotDistinctFrom(left,
                new RexInputRef(corIndex + groupKeySize,
                    newRelFields.get(corIndex).getType())));
      }
    }

    ImmutableBitSet groupSet = corFieldBuilder.build();
    // Build [09] LogicalAggregate(group=[{0}]) to obtain the distinct set of
    // corVar from outFrame.
    relBuilder.push(outFrame.r)
        .aggregate(relBuilder.groupKey(groupSet));

    // Build [08] LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[left])
    // to ensure each corVar's aggregate result is output.
    final RelNode join = relBuilder.push(newRel)
        .join(JoinRelType.LEFT, conditions).build();

    // For each COUNT call, look up where its result landed in newRel (old
    // position = group key count + call ordinal), shift it into join
    // coordinates, and wrap it so that an unmatched left row yields 0 rather
    // than NULL.
    for (int i1 = 0; i1 < oldRel.getAggCallList().size(); i1++) {
      AggregateCall aggCall = oldRel.getAggCallList().get(i1);
      if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
        int index = requireNonNull(outputMap.get(i1 + oldRel.getGroupSet().size()));
        final RexInputRef ref = RexInputRef.of(index + groupKeySize, join.getRowType());
        RexNode specificCountValue =
            rexBuilder.makeCall(SqlStdOperatorTable.CASE,
                ImmutableList.of(relBuilder.isNotNull(ref), ref, relBuilder.literal(0)));
        newProjectMap.put(ref.getIndex(), specificCountValue);
      }
    }

    // Project only the join fields at ordinals >= groupKeySize (the fields
    // contributed by newRel, assuming the distinct aggregate contributes
    // exactly groupKeySize leading fields), substituting the replacements
    // recorded in newProjectMap. The result therefore keeps newRel's field
    // count and order.
    final List<RexNode> newProjects = new ArrayList<>();
    for (int index : ImmutableBitSet.range(groupKeySize, join.getRowType().getFieldCount())) {
      if (newProjectMap.containsKey(index)) {
        newProjects.add(requireNonNull(newProjectMap.get(index)));
      } else {
        newProjects.add(RexInputRef.of(index, join.getRowType()));
      }
    }

    // Build [07] LogicalProject(DEPTNO=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)])
    // to handle COUNT function by converting nulls to zero.
    return relBuilder.push(join)
        .project(newProjects, newRel.getRowType().getFieldNames())
        .build();
  }

  /**
   * Shifts, in place, every target of a mapping that lies at or beyond a given
   * position.
   *
   * @param mapping    Mapping whose values are adjusted; keys are left untouched
   * @param startIndex Values greater than or equal to this index are shifted
   * @param offset     Amount added to each shifted value
   */
  private static void shiftMapping(Map<Integer, Integer> mapping, int startIndex, int offset) {
    for (Map.Entry<Integer, Integer> slot : mapping.entrySet()) {
      final int target = slot.getValue();
      if (target < startIndex) {
        // Located before the insertion point: position is unaffected.
        continue;
      }
      slot.setValue(target + offset);
    }
  }

  /**
   * Dispatches a relational expression to its decorrelation handler and
   * records the resulting frame, if any.
   *
   * <p>After the dispatch returns, {@code currentRel} is reset to
   * {@code parent}, whether or not a frame was produced.
   *
   * @param r                          Relational expression to decorrelate
   * @param isCorVarDefined            Flag forwarded unchanged to the handler
   * @param parent                     Expression that becomes {@code currentRel}
   *                                   once the dispatch has completed
   * @param parentPropagatesNullValues True if the parent RelNode produces null
   *                                   when all of its inputs fields are null.
   * @return the frame produced for {@code r}, or null if the handler produced
   *     none
   */
  public @Nullable Frame getInvoke(RelNode r, boolean isCorVarDefined,
      @Nullable RelNode parent, boolean parentPropagatesNullValues) {
    final Frame result = dispatcher.invoke(r, isCorVarDefined, parentPropagatesNullValues);
    currentRel = parent;
    if (result == null) {
      return null;
    }
    // Remember the frame so later lookups for r find it.
    map.put(r, result);
    return result;
  }

  /**
   * Returns the {@code i}-th projected expression of {@code rel} if
   * {@code rel} is a {@link Project} and that expression is a literal;
   * otherwise returns null.
   */
  private static @Nullable RexLiteral projectedLiteral(RelNode rel, int i) {
    if (!(rel instanceof Project)) {
      return null;
    }
    final RexNode expr = ((Project) rel).getProjects().get(i);
    return expr instanceof RexLiteral ? (RexLiteral) expr : null;
  }

  /**
   * Decorrelates a {@link Sort} that keeps only its first row.
   *
   * <p>The MIN/MAX rewrite of
   * {@link #decorrelateSortAsAggregate(Sort, Frame)} is tried first. If it does
   * not apply (returns null), every output field is replaced by a
   * {@code FIRST_VALUE} window call partitioned by the correlated variables and
   * ordered by the sort collation, and the result is made distinct.
   *
   * <p>This method does not itself check OFFSET or FETCH; per the comment in
   * the body, the caller is expected to have established "no offset,
   * fetch = 1".
   *
   * @param sort  the Sort to rewrite
   * @param frame frame of the already-decorrelated input of {@code sort}
   * @return the frame registered for {@code sort}
   */
  protected @Nullable Frame decorrelateFetchOneSort(Sort sort, final Frame frame) {
    // Prefer the cheaper aggregate form when it is applicable.
    Frame aggFrame = decorrelateSortAsAggregate(sort, frame);
    if (aggFrame != null) {
      return aggFrame;
    }
    //
    // Rewrite logic:
    //
    // If sorted without offset and fetch = 1 (enforced by the caller), rewrite the sort to be
    //   Aggregate(group=(corVar.. , field..))
    //     project(first_value(field) over (partition by corVar order by (sort collation)))
    //       input
    //
    // 1. For the original sorted input, apply the FIRST_VALUE window function to produce
    //    the result of sorting with LIMIT 1, and the same as the decorrelate of aggregate,
    //    add correlated variables in partition list to maintain semantic consistency.
    // 2. To ensure that there is at most one row of output for
    //    any combination of correlated variables, distinct for correlated variables.
    // 3. Since we have partitioned by all correlated variables
    //    in the sorted output field window, so for any combination of correlated variables,
    //    all other field values are unique. So the following two are equivalent:
    //      - group by corVar1, covVar2, field1, field2
    //      - any_value(fields1), any_value(fields2) group by corVar1, covVar2
    //    Here we use the first.
    final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
    final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();

    // References to the correlated variables in the input. They are appended
    // after the sort's own fields, so the n-th one lands at
    // "sort field count + n" in the new output.
    final PairList<RexNode, String> corVarProjects = PairList.of();
    List<RelDataTypeField> fieldList = frame.r.getRowType().getFieldList();
    for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
      corDefOutputs.put(entry.getKey(),
          sort.getRowType().getFieldCount() + corVarProjects.size());
      RexInputRef.add2(corVarProjects, entry.getValue(), fieldList);
    }

    // Translate the sort collation into ORDER BY expressions over the new
    // input, carrying over the direction and any explicit null direction.
    final List<RexNode> sortExprs =
        new ArrayList<>(sort.getCollation().getFieldCollations().size());
    for (RelFieldCollation collation : sort.getCollation().getFieldCollations()) {
      Integer newIdx = requireNonNull(frame.oldToNewOutputs.get(collation.getFieldIndex()));
      RexNode node = RexInputRef.of(newIdx, fieldList);
      if (collation.direction == RelFieldCollation.Direction.DESCENDING) {
        node = relBuilder.desc(node);
      }
      if (collation.nullDirection == RelFieldCollation.NullDirection.FIRST) {
        node = relBuilder.nullsFirst(node);
      } else if (collation.nullDirection == RelFieldCollation.NullDirection.LAST) {
        node = relBuilder.nullsLast(node);
      }
      sortExprs.add(node);
    }

    final PairList<RexNode, String> newProjExprs = PairList.of();
    for (RelDataTypeField field : sort.getRowType().getFieldList()) {
      final int newIdx =
          requireNonNull(frame.oldToNewOutputs.get(field.getIndex()));

      RelBuilder.AggCall aggCall =
          relBuilder.aggregateCall(SqlStdOperatorTable.FIRST_VALUE,
              RexInputRef.of(newIdx, fieldList));

      // Convert each field from the sorted output to a window function that partitions by
      // correlated variables, orders by the collation, and return the first_value.
      RexNode winCall = aggCall.over()
          .orderBy(sortExprs)
          .partitionBy(corVarProjects.leftList())
          .toRex();
      // Sort fields keep their ordinal: the i-th old field is the i-th new field.
      mapOldToNewOutputs.put(newProjExprs.size(), newProjExprs.size());
      newProjExprs.add(winCall, field.getName());
    }
    // Correlated variables follow the windowed fields, matching the positions
    // recorded in corDefOutputs above.
    newProjExprs.addAll(corVarProjects);
    RelNode result = relBuilder.push(frame.r)
        .project(newProjExprs.leftList(), newProjExprs.rightList())
        .distinct().build();

    return register(sort, result, mapOldToNewOutputs, corDefOutputs);
  }

  /**
   * Tries to rewrite a "first row" {@link Sort} as a MIN/MAX aggregate grouped
   * by the correlated variables.
   *
   * <p>The rewrite applies only when the sort has exactly one collation field,
   * produces exactly one column, the input carries at least one correlated
   * variable, and the collation puts nulls last. For example
   *
   * <blockquote><pre>
   * Sort(sort0=[$0], dir0=[ASC], fetch=[1])
   *   input</pre></blockquote>
   *
   * <p>becomes
   *
   * <blockquote><pre>
   * Aggregate(group=(corVar), agg=[min($0)])
   *   input</pre></blockquote>
   *
   * <p>MIN/MAX is not equivalent to LIMIT 1 in general: over an empty input an
   * ungrouped MIN/MAX yields one NULL row while LIMIT 1 yields no rows. Because
   * the correlated variables are added as group keys, the group list is never
   * empty here, and a grouped aggregate over an empty input yields no rows as
   * well, which makes the transformation legal.
   *
   * @param sort  the Sort to rewrite
   * @param frame frame of the already-decorrelated input of {@code sort}
   * @return the frame registered for {@code sort}, or null if the rewrite does
   *     not apply
   */
  protected @Nullable Frame decorrelateSortAsAggregate(Sort sort, final Frame frame) {
    final boolean singleKeySingleColumn =
        sort.getCollation().getFieldCollations().size() == 1
            && sort.getRowType().getFieldCount() == 1;
    if (!singleKeySingleColumn || frame.corDefOutputs.isEmpty()) {
      return null;
    }

    final RelFieldCollation fieldCollation =
        Util.first(sort.getCollation().getFieldCollations());
    if (fieldCollation.nullDirection != RelFieldCollation.NullDirection.LAST) {
      return null;
    }

    // Ascending order keeps the smallest value, descending the largest.
    final SqlAggFunction extremum;
    switch (fieldCollation.getDirection()) {
    case ASCENDING:
    case STRICTLY_ASCENDING:
      extremum = SqlStdOperatorTable.MIN;
      break;
    case DESCENDING:
    case STRICTLY_DESCENDING:
      extremum = SqlStdOperatorTable.MAX;
      break;
    default:
      return null;
    }

    final int sortField =
        requireNonNull(frame.oldToNewOutputs.get(fieldCollation.getFieldIndex()));
    // The input must be on the builder stack before its fields are resolved.
    relBuilder.push(frame.r);
    final RelBuilder.AggCall extremumCall =
        relBuilder.aggregateCall(extremum,
            relBuilder.fields(ImmutableList.of(sortField)));

    // As with the aggregate decorrelate, the correlated variables become the
    // group keys; the n-th variable is the n-th output field.
    final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
    final List<RexInputRef> groupKey = new ArrayList<>();
    final RelDataType inputRowType = frame.r.getRowType();
    for (Map.Entry<CorDef, Integer> corVar : frame.corDefOutputs.entrySet()) {
      corDefOutputs.put(corVar.getKey(), groupKey.size());
      groupKey.add(RexInputRef.of(corVar.getValue(), inputRowType));
    }

    final RelNode aggregate =
        relBuilder.aggregate(relBuilder.groupKey(groupKey), extremumCall).build();

    // The single sort column is now the aggregate result, which follows the
    // group keys.
    final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
    mapOldToNewOutputs.put(0, groupKey.size());
    return register(sort, aggregate, mapOldToNewOutputs, corDefOutputs);
  }

  /**
   * Decorrelates a {@link LogicalProject} by delegating to the generic
   * {@link Project} overload; the cast selects that overload explicitly.
   */
  public @Nullable Frame decorrelateRel(LogicalProject rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    return decorrelateRel((Project) rel, isCorVarDefined, parentPropagatesNullValues);
  }

  /**
   * Decorrelates a {@link Project}.
   *
   * <p>The new Project emits the rewritten original expressions, in their
   * original order, followed by one extra column per correlated variable that
   * its input passes along.
   *
   * @param rel                        Project to rewrite
   * @param isCorVarDefined            Flag forwarded to the input's handler
   * @param parentPropagatesNullValues Whether the parent propagates nulls; it
   *                                   is cleared before visiting the input if
   *                                   any projected expression is not strong
   * @return the frame registered for {@code rel}, or null if the input was not
   *     rewritten
   */
  public @Nullable Frame decorrelateRel(Project rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    // One expression that is not strong (per Strong.isStrong) is enough to
    // stop treating this Project as null-propagating.
    for (RexNode expr : rel.getProjects()) {
      if (!Strong.isStrong(expr)) {
        parentPropagatesNullValues = false;
      }
    }

    Frame frame =
        getInvoke(rel.getInput(), isCorVarDefined, rel, parentPropagatesNullValues);
    if (frame == null) {
      // The input was left as is, so this rel is left as is too.
      return null;
    }

    // A Project that itself references correlated variables needs them
    // produced by its input first.
    if (cm.mapRefRelToCorRef.containsKey(rel)) {
      frame = decorrelateInputWithValueGenerator(rel, frame);
    }

    // Part 1: the original expressions, position for position.
    final List<RexNode> oldExprs = rel.getProjects();
    final List<RelDataTypeField> oldFields = rel.getRowType().getFieldList();
    final PairList<RexNode, String> newExprs = PairList.of();
    final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
    for (int i = 0; i < oldExprs.size(); i++) {
      final RexNode rewritten =
          decorrelateExpr(requireNonNull(currentRel, "currentRel"),
              map, cm, oldExprs.get(i));
      newExprs.add(rewritten, oldFields.get(i).getName());
      mapOldToNewOutputs.put(i, i);
    }

    // Part 2: the correlated variables handed up by the input, appended after
    // the original expressions.
    final NavigableMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
    final List<RelDataTypeField> inputFields = frame.r.getRowType().getFieldList();
    for (Map.Entry<CorDef, Integer> corVar : frame.corDefOutputs.entrySet()) {
      corDefOutputs.put(corVar.getKey(), newExprs.size());
      RexInputRef.add2(newExprs, corVar.getValue(), inputFields);
    }

    final RelNode newProject = relBuilder.push(frame.r)
        .projectNamed(newExprs.leftList(), newExprs.rightList(), true)
        .build();

    return register(rel, newProject, mapOldToNewOutputs, corDefOutputs);
  }

  /**
   * Builds a relational expression that produces the distinct values of a set
   * of correlated variables.
   *
   * <p>For every rewritten input that defines one of the variables, the
   * referenced fields are projected and made distinct; the per-input results
   * are then combined with inner joins on TRUE, in the order in which the
   * inputs first occur in {@code correlations}.
   *
   * @param correlations         correlated variables to generate
   * @param valueGenFieldOffset  offset in the output that generated columns
   *                             will start
   * @param corDefOutputs        output positions for the correlated variables
   *                             generated; populated by this method
   * @return the root of the resulting tree, or null if {@code correlations}
   *     is empty
   */
  private @Nullable RelNode createValueGenerator(
      Iterable<CorRef> correlations,
      int valueGenFieldOffset,
      NavigableMap<CorDef, Integer> corDefOutputs) {
    // Pass 1: for each defining input, collect the distinct new field
    // positions that are referenced, in order of first reference.
    final Map<RelNode, List<Integer>> positionsByInput = new HashMap<>();
    for (CorRef corRef : correlations) {
      final RelNode definingInput = requireNonNull(getCorRel(corRef));
      final Frame frame = requireNonNull(getOrCreateFrame(definingInput));
      final int newField =
          requireNonNull(frame.oldToNewOutputs.get(corRef.field));

      final List<Integer> positions =
          positionsByInput.computeIfAbsent(frame.r, k -> new ArrayList<>());
      if (!positions.contains(newField)) {
        positions.add(newField);
      }
    }

    // Pass 2: project the referenced fields out of each input, make them
    // distinct, and join the pieces together. Inputs are visited in the order
    // of the (sorted) correlation list so that the join order is stable. An
    // input is present in offsetByInput exactly when it has been joined.
    final Map<RelNode, Integer> offsetByInput = new HashMap<>();
    int nextOffset = 0;
    RelNode generator = null;
    for (CorRef corRef : correlations) {
      final RelNode definingInput = requireNonNull(getCorRel(corRef));
      final RelNode newInput = requireNonNull(getOrCreateFrame(definingInput).r);
      if (offsetByInput.containsKey(newInput)) {
        continue;
      }

      final List<Integer> positions =
          requireNonNull(positionsByInput.get(newInput),
              () -> "mapNewInputToOutputs.get(" + newInput + ")");
      final RelNode distinct = relBuilder.push(newInput)
          .project(relBuilder.fields(positions))
          .distinct()
          .build();

      offsetByInput.put(newInput, nextOffset);
      nextOffset += distinct.getRowType().getFieldCount();

      if (generator == null) {
        generator = distinct;
      } else {
        final RelOptCluster cluster = distinct.getCluster();
        generator = relBuilder.push(generator).push(distinct)
            .join(JoinRelType.INNER, cluster.getRexBuilder().makeLiteral(true))
            .build();
      }
    }

    // Pass 3: express each variable's position relative to the joined output.
    // valueGenFieldOffset leaves room for the fields of the rel that the
    // generator will be joined to on its left.
    for (CorRef corRef : correlations) {
      final RelNode definingInput = requireNonNull(getCorRel(corRef));
      final Frame frame = getOrCreateFrame(definingInput);
      final RelNode newInput = requireNonNull(frame.r);

      final List<Integer> positions =
          requireNonNull(positionsByInput.get(newInput),
              () -> "mapNewInputToOutputs.get(" + newInput + ")");
      final int newField = requireNonNull(frame.oldToNewOutputs.get(corRef.field));
      final int inputOffset =
          requireNonNull(offsetByInput.get(newInput),
              () -> "mapNewInputToNewOffset.get(" + newInput + ")");

      // Index within the input's projected list, plus where that list starts
      // in the join, plus the caller's offset.
      corDefOutputs.put(corRef.def(),
          positions.indexOf(newField) + inputOffset + valueGenFieldOffset);
    }

    return generator;
  }

  /**
   * Returns the frame recorded for a relational expression, or, if there is
   * none, a fresh frame that maps the expression to itself with an identity
   * output mapping and no correlated variables. The fresh frame is not
   * recorded.
   */
  private Frame getOrCreateFrame(RelNode r) {
    final Frame existing = getFrame(r);
    if (existing != null) {
      return existing;
    }
    return new Frame(r, r, ImmutableSortedMap.of(),
        identityMap(r.getRowType().getFieldCount()));
  }

  /** Returns the frame recorded in {@code map} for a relational expression,
   * or null if none has been recorded. */
  private @Nullable Frame getFrame(RelNode r) {
    return map.get(r);
  }

  /**
   * Returns the first input of the relational expression registered in
   * {@code cm.mapCorToCorRel} for the correlation id of {@code corVar}.
   *
   * @throws NullPointerException if no expression is registered for the id, or
   *     if its first input is null
   */
  private RelNode getCorRel(CorRef corVar) {
    final RelNode correlate =
        requireNonNull(cm.mapCorToCorRel.get(corVar.corr),
            () -> "cm.mapCorToCorRel.get(" + corVar.corr + ")");
    final RelNode definingInput = correlate.getInput(0);
    return requireNonNull(definingInput,
        () -> "r.getInput(0) is null for " + correlate);
  }

  /**
   * Adds a value generator below a relational expression when the expression
   * uses correlating variables that its input does not already provide.
   *
   * <p>The variables in use are recomputed from {@code frame.r} and
   * {@code rel} rather than taken from the decorrelator-wide map. If
   * {@code rel} uses none, or the frame already outputs all of them, the frame
   * is returned unchanged.
   */
  private Frame maybeAddValueGenerator(RelNode rel, Frame frame) {
    final CorelMap localMap = new CorelMapBuilder().build(frame.r, rel);
    final boolean generatorNeeded =
        localMap.mapRefRelToCorRef.containsKey(rel)
            && !hasAll(localMap.mapRefRelToCorRef.get(rel),
                frame.corDefOutputs.keySet());
    if (generatorNeeded) {
      return decorrelateInputWithValueGenerator(rel, frame);
    }
    return frame;
  }

  /**
   * Returns whether every {@link CorRef} in a collection is satisfied by at
   * least one {@link CorDef} in another collection. An empty collection of
   * references is trivially satisfied.
   */
  private static boolean hasAll(Collection<CorRef> corRefs,
      Collection<CorDef> corDefs) {
    return corRefs.stream().allMatch(corRef -> has(corDefs, corRef));
  }

  /**
   * Returns whether a {@link CorRef} is satisfied by at least one of a
   * collection of {@link CorDef}s, that is, whether some definition has the
   * same correlation id and the same field ordinal as the reference.
   */
  private static boolean has(Collection<CorDef> corDefs, CorRef corr) {
    return corDefs.stream()
        .anyMatch(def -> def.corr.equals(corr.corr) && def.field == corr.field);
  }

  /**
   * Rewrites the input of a relational expression so that it also outputs the
   * correlated variables that the expression references.
   *
   * <p>For a {@link Filter}, the method first tries to find, for each needed
   * variable, a local expression that the filter condition equates with it; if
   * that succeeds for every reference, those local expressions stand in for
   * the variables and no value generator is created. Otherwise the input is
   * inner-joined (with no join condition) to a value generator built by
   * {@code createValueGenerator}.
   *
   * @param rel   expression that references correlated variables; must have
   *              exactly one input (checked by an assertion only)
   * @param frame frame of the already-rewritten input of {@code rel}
   * @return the frame registered for the input of {@code rel}
   */
  private Frame decorrelateInputWithValueGenerator(RelNode rel, Frame frame) {
    // currently only handles one input
    assert rel.getInputs().size() == 1;
    // Despite the name, this is the rewritten input carried by the frame.
    RelNode oldInput = frame.r;

    // Start from the variables the input already provides; the value
    // generator path below adds the generated ones to this map.
    final NavigableMap<CorDef, Integer> corDefOutputs =
        new TreeMap<>(frame.corDefOutputs);

    final Collection<CorRef> corVarList = cm.mapRefRelToCorRef.get(rel);

    // Try to populate correlation variables using local fields.
    // This means that we do not need a value generator.
    if (rel instanceof Filter) {
      // Note: this local deliberately or accidentally shadows the field "map"
      // within the Filter branch.
      NavigableMap<CorDef, Integer> map = new TreeMap<>();
      // Equivalent expressions that are not plain input references; they are
      // appended to the input as extra projected columns.
      List<RexNode> projects = new ArrayList<>();
      for (CorRef correlation : corVarList) {
        final CorDef def = correlation.def();
        if (corDefOutputs.containsKey(def) || map.containsKey(def)) {
          continue;
        }
        try {
          // Signals a hit by throwing FoundOne; returning normally means no
          // equivalent was found and the variable stays unmapped.
          findCorrelationEquivalent(correlation, ((Filter) rel).getCondition());
        } catch (Util.FoundOne e) {
          Object node = requireNonNull(e.getNode(), "e.getNode()");
          if (node instanceof RexInputRef) {
            // The variable equals an existing input field: reuse its ordinal.
            map.put(def, ((RexInputRef) node).getIndex());
          } else {
            // The variable equals a computed expression: it will occupy the
            // next column after the input's fields.
            map.put(def,
                frame.r.getRowType().getFieldCount() + projects.size());
            projects.add((RexNode) node);
          }
        }
      }
      // If all correlation variables are now satisfied, skip creating a value
      // generator.
      // The comparison is against the number of references, so any reference
      // skipped above (already provided by the input, or sharing a definition
      // with an earlier reference) also makes the sizes differ, in which case
      // the value generator path below is taken.
      if (map.size() == corVarList.size()) {
        map.putAll(frame.corDefOutputs);
        final RelNode r;
        if (!projects.isEmpty()) {
          relBuilder.push(oldInput)
              .project(Iterables.concat(relBuilder.fields(), projects));
          r = relBuilder.build();
        } else {
          r = oldInput;
        }
        return register(rel.getInput(0), r,
            frame.oldToNewOutputs, map);
      }
    }

    // Generated columns start right after the fields of the current input.
    int leftInputOutputCount = frame.r.getRowType().getFieldCount();

    // can directly add positions into corDefOutputs since join
    // does not change the output ordering from the inputs.
    final RelNode valueGen =
        createValueGenerator(corVarList, leftInputOutputCount, corDefOutputs);
    requireNonNull(valueGen, "valueGen");

    // Inner join without a condition: every input row is paired with every
    // generated combination of correlated values.
    RelNode join =
        relBuilder
            .push(frame.r)
            .push(valueGen)
            .join(JoinRelType.INNER)
            .build();

    // Join or Filter does not change the old input ordering. All
    // input fields from newLeftInput (i.e. the original input to the old
    // Filter) are in the output and in the same position.
    return register(rel.getInput(0), join, frame.oldToNewOutputs,
        corDefOutputs);
  }

  /**
   * Searches a condition for an expression that is equated with a
   * {@link CorRef} and, if one is found, throws a
   * {@link org.apache.calcite.util.Util.FoundOne} carrying that expression.
   *
   * <p>Only {@code EQUALS} calls are inspected, either at the top level or
   * nested inside {@code AND}. Returning normally means that nothing was found.
   *
   * <p>The equivalent expression must not contain any {@link RexFieldAccess},
   * ensuring that we only map the correlation variable to a local field or
   * expression from the current relational expression (e.g., a
   * {@link RexInputRef}), rather than to another correlation variable.
   */
  private static void findCorrelationEquivalent(CorRef correlation, RexNode e)
      throws Util.FoundOne {
    switch (e.getKind()) {
    case EQUALS: {
      final List<RexNode> sides = ((RexCall) e).getOperands();
      final RexNode lhs = sides.get(0);
      final RexNode rhs = sides.get(1);
      // "corVar = local": the right-hand side is the equivalent.
      if (!RexUtil.containsFieldAccess(rhs) && references(lhs, correlation)) {
        throw new Util.FoundOne(rhs);
      }
      // "local = corVar": the left-hand side is the equivalent.
      if (!RexUtil.containsFieldAccess(lhs) && references(rhs, correlation)) {
        throw new Util.FoundOne(lhs);
      }
      break;
    }
    case AND:
      for (RexNode conjunct : ((RexCall) e).getOperands()) {
        findCorrelationEquivalent(correlation, conjunct);
      }
      break;
    default:
      break;
    }
  }

  /**
   * Returns whether an expression is a reference to a given correlated field,
   * looking through casts that merely widen the type.
   *
   * <p>An expression qualifies if it is a field access on a
   * {@link RexCorrelVariable} whose correlation id equals
   * {@code correlation.corr} and whose field ordinal equals
   * {@code correlation.field}, optionally wrapped in casts for which
   * {@link #isWidening} holds.
   *
   * <p>Correlation ids are compared with {@code equals}, not by identity, so
   * that two distinct but equal {@link CorrelationId} instances are recognized
   * as the same variable; this matches how ids are compared elsewhere in this
   * class.
   */
  private static boolean references(RexNode e, CorRef correlation) {
    switch (e.getKind()) {
    case CAST:
      final RexNode operand = ((RexCall) e).getOperands().get(0);
      if (isWidening(e.getType(), operand.getType())) {
        return references(operand, correlation);
      }
      return false;
    case FIELD_ACCESS:
      final RexFieldAccess f = (RexFieldAccess) e;
      if (f.getField().getIndex() != correlation.field
          || !(f.getReferenceExpr() instanceof RexCorrelVariable)) {
        return false;
      }
      return ((RexCorrelVariable) f.getReferenceExpr()).id.equals(correlation.corr);
    default:
      return false;
    }
  }

  /**
   * Returns whether {@code type} is just a widening of {@code type1}: both
   * have the same SQL type name and the precision of {@code type} is at least
   * that of {@code type1}. Nullability is not compared.
   *
   * <p>For example:<ul>
   * <li>{@code VARCHAR(10)} is a widening of {@code VARCHAR(5)}.
   * <li>{@code VARCHAR(10)} is a widening of {@code VARCHAR(10) NOT NULL}.
   * </ul>
   */
  private static boolean isWidening(RelDataType type, RelDataType type1) {
    if (type.getSqlTypeName() != type1.getSqlTypeName()) {
      return false;
    }
    return type1.getPrecision() <= type.getPrecision();
  }

  /**
   * Decorrelates a {@link LogicalSnapshot}. A snapshot whose period expression
   * contains a correlation is not rewritten (null is returned); any other
   * snapshot is handled by the generic {@link RelNode} overload.
   */
  public @Nullable Frame decorrelateRel(LogicalSnapshot rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    return RexUtil.containsCorrelation(rel.getPeriod())
        ? null
        : decorrelateRel((RelNode) rel, isCorVarDefined, parentPropagatesNullValues);
  }

  /**
   * Decorrelates a {@link LogicalTableFunctionScan}. A scan whose call
   * expression contains a correlation is not rewritten (null is returned); any
   * other scan is handled by the generic {@link RelNode} overload.
   */
  public @Nullable Frame decorrelateRel(LogicalTableFunctionScan rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    return RexUtil.containsCorrelation(rel.getCall())
        ? null
        : decorrelateRel((RelNode) rel, isCorVarDefined, parentPropagatesNullValues);
  }

  /**
   * Decorrelates a {@link LogicalFilter} by delegating to the generic
   * {@link Filter} overload; the cast selects that overload explicitly.
   */
  public @Nullable Frame decorrelateRel(LogicalFilter rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    return decorrelateRel((Filter) rel, isCorVarDefined, parentPropagatesNullValues);
  }

  /**
   * Decorrelates a {@link Filter}.
   *
   * <p>Rewrite logic:
   *
   * <ol>
   * <li>If the Filter references a correlated field in its condition and the
   * input does not already provide it, the input is extended so that it does
   * (see {@code maybeAddValueGenerator}); in the general case the result is
   * <blockquote><pre>
   * Filter
   *   Join(cross product)
   *     originalFilterInput
   *     ValueGenerator(produces distinct sets of correlated variables)</pre></blockquote>
   * and the correlated field accesses in the condition are rewritten to
   * reference the new input.
   * <li>If the Filter does not reference correlated variables, the condition
   * is simply rewritten against the new input.
   * </ol>
   *
   * @param rel                        Filter to rewrite
   * @param isCorVarDefined            Flag forwarded to the input's handler
   * @param parentPropagatesNullValues Flag forwarded to the input's handler
   * @return the frame registered for {@code rel}, or null if the input was not
   *     rewritten
   */
  public @Nullable Frame decorrelateRel(Filter rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    final RelNode oldInput = rel.getInput();
    Frame frame = getInvoke(oldInput, isCorVarDefined, rel, parentPropagatesNullValues);
    if (frame == null) {
      // If input has not been rewritten, do not rewrite this rel.
      return null;
    }

    // Make sure the input outputs every correlated variable that this Filter
    // references, adding a value generator only for those that are missing.
    frame = maybeAddValueGenerator(rel, frame);

    // Correlation references are resolved against a map built for this Filter
    // alone rather than the decorrelator-wide one.
    final CorelMap cm2 = new CorelMapBuilder().build(rel);

    // Replace the filter expression to reference output of the join
    // Map filter to the new filter over join
    relBuilder.push(frame.r)
        .filter(decorrelateExpr(castNonNull(currentRel), map, cm2, rel.getCondition()));

    // Filter does not change the input ordering.
    // Filter rel does not permute the input.
    // All corVars produced by filter will have the same output positions in the
    // input rel.
    return register(rel, relBuilder.build(), frame.oldToNewOutputs,
        frame.corDefOutputs);
  }

  /** Decorrelates a {@link LogicalCorrelate} by delegating to the handler for
   * the more general {@link Correlate}.
   *
   * @param rel Correlate to rewrite
   * @param isCorVarDefined Passed on to the {@link Correlate} handler
   * @param parentPropagatesNullValues Passed on to the {@link Correlate} handler
   * @return whatever the {@link Correlate} handler returns
   */
  public @Nullable Frame decorrelateRel(LogicalCorrelate rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    // The static type of the argument selects the Correlate overload.
    final Correlate correlate = rel;
    return decorrelateRel(correlate, isCorVarDefined, parentPropagatesNullValues);
  }

  /**
   * Decorrelates a {@link Correlate} by turning it into a join of the
   * rewritten left input and the rewritten right input.
   *
   * <p>Returns null, leaving the Correlate as is, if the left input was not
   * rewritten, if the right input was not rewritten, or if the rewritten right
   * input does not output any correlated variables.
   *
   * @param rel Correlate to rewrite
   * @param isCorVarDefined Passed on when visiting the left input; the right
   *     input is always visited with {@code true}
   * @param parentPropagatesNullValues Passed on when visiting the left input;
   *     for the right input it is also true if this is a LEFT correlate
   * @return frame for the new join, or null if the Correlate was not rewritten
   */
  public @Nullable Frame decorrelateRel(Correlate rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    //
    // Rewrite logic:
    //
    // The original left input will be joined with the new right input that
    // has generated correlated variables propagated up. For any generated
    // corVars that are not used in the join key, pass them along to be
    // joined later with the Correlates that produce them.
    //

    // the right input to Correlate should produce correlated variables
    final RelNode oldLeft = rel.getInput(0);
    final RelNode oldRight = rel.getInput(1);

    final Frame leftFrame = getInvoke(oldLeft, isCorVarDefined, rel, parentPropagatesNullValues);
    if (leftFrame == null) {
      // If input has not been rewritten, do not rewrite this rel.
      return null;
    }

    // The left frame stays on the stack, keyed by this Correlate's id, only
    // for the duration of the visit to the right input.
    frameStack.push(Pair.of(rel.getCorrelationId(), leftFrame));
    final Frame rightFrame =
        getInvoke(oldRight, true, rel,
            rel.getJoinType() == JoinRelType.LEFT || parentPropagatesNullValues);
    frameStack.pop();

    // Without a rewritten right input that outputs correlated variables there
    // is nothing to join on, so leave the Correlate alone.
    if (rightFrame == null || rightFrame.corDefOutputs.isEmpty()) {
      return null;
    }

    assert rel.getRequiredColumns().cardinality()
        <= rightFrame.corDefOutputs.keySet().size();

    // Change correlator rel into a join.
    // Join all the correlated variables produced by this correlator rel
    // with the values generated and propagated from the right input
    final NavigableMap<CorDef, Integer> corDefOutputs =
        new TreeMap<>(rightFrame.corDefOutputs);
    final List<RexNode> conditions = new ArrayList<>();
    final List<RelDataTypeField> newLeftOutput =
        leftFrame.r.getRowType().getFieldList();
    int newLeftFieldCount = newLeftOutput.size();

    final List<RelDataTypeField> newRightOutput =
        rightFrame.r.getRowType().getFieldList();

    // Iterate over a copy of the entries, because entries are removed from
    // corDefOutputs inside the loop.
    for (Map.Entry<CorDef, Integer> rightOutput
        : new ArrayList<>(corDefOutputs.entrySet())) {
      final CorDef corDef = rightOutput.getKey();
      // Only variables defined by this Correlate become join keys here.
      if (!corDef.corr.equals(rel.getCorrelationId())) {
        continue;
      }
      final int newLeftPos = requireNonNull(leftFrame.oldToNewOutputs.get(corDef.field));
      final int newRightPos = rightOutput.getValue();

      // Using `equals` instead of `IS NOT DISTINCT FROM` is an optimization
      // for non-nullable fields. However, `IS NOT DISTINCT FROM` is always
      // the correct choice in all cases.
      if (isFieldNotNull(rightFrame.r, newRightPos)) {
        conditions.add(
            relBuilder.equals(RexInputRef.of(newLeftPos, newLeftOutput),
                new RexInputRef(newLeftFieldCount + newRightPos,
                    newRightOutput.get(newRightPos).getType())));
      } else {
        conditions.add(
            relBuilder.isNotDistinctFrom(RexInputRef.of(newLeftPos, newLeftOutput),
                new RexInputRef(newLeftFieldCount + newRightPos,
                    newRightOutput.get(newRightPos).getType())));
      }
      // remove this corVar from output position mapping
      corDefOutputs.remove(corDef);
    }

    // Update the output position for the corVars: only pass on the cor
    // vars that are not used in the join key.
    for (Map.Entry<CorDef, Integer> entry : corDefOutputs.entrySet()) {
      entry.setValue(entry.getValue() + newLeftFieldCount);
    }

    // then add any corVar from the left input. Do not need to change
    // output positions.
    corDefOutputs.putAll(leftFrame.corDefOutputs);

    // Create the mapping between the output of the old correlation rel
    // and the new join rel
    final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();

    int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();

    int oldRightFieldCount = oldRight.getRowType().getFieldCount();
    //noinspection AssertWithSideEffects
    assert rel.getRowType().getFieldCount()
        == oldLeftFieldCount + oldRightFieldCount;

    // Left input positions are not changed.
    mapOldToNewOutputs.putAll(leftFrame.oldToNewOutputs);

    // Right input positions are shifted by newLeftFieldCount.
    for (int i = 0; i < oldRightFieldCount; i++) {
      mapOldToNewOutputs.put(i + oldLeftFieldCount,
          requireNonNull(rightFrame.oldToNewOutputs.get(i)) + newLeftFieldCount);
    }

    // The join condition is the conjunction of all the key comparisons
    // collected above; the join keeps the Correlate's join type.
    final RexNode condition =
        RexUtil.composeConjunction(relBuilder.getRexBuilder(), conditions);
    RelNode newJoin = relBuilder.push(leftFrame.r).push(rightFrame.r)
        .join(rel.getJoinType(), condition).build();

    return register(rel, newJoin, mapOldToNewOutputs, corDefOutputs);
  }

  /** Decorrelates a {@link LogicalJoin} by delegating to the handler for the
   * more general {@link Join}.
   *
   * @param rel Join to rewrite
   * @param isCorVarDefined Passed on to the {@link Join} handler
   * @param parentPropagatesNullValues Passed on to the {@link Join} handler
   * @return whatever the {@link Join} handler returns
   */
  public @Nullable Frame decorrelateRel(LogicalJoin rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    // The static type of the argument selects the Join overload.
    final Join join = rel;
    return decorrelateRel(join, isCorVarDefined, parentPropagatesNullValues);
  }

  /**
   * Decorrelates a {@link Join}.
   *
   * <p>For join types that project their right input, both inputs are
   * decorrelated, the join condition is rewritten against the new inputs, and
   * the output positions and correlated variables of both sides are merged
   * (those of the right side shifted by the width of the new left input).
   * Join types that do not project their right input are handed to the
   * generic {@link RelNode} handler.
   *
   * @param rel Join to rewrite
   * @param isCorVarDefined Passed on when visiting both inputs
   * @param parentPropagatesNullValues Passed on when visiting both inputs
   * @return frame for the rewritten join; null if either input of a
   *     right-projecting join was not rewritten
   */
  public @Nullable Frame decorrelateRel(Join rel, boolean isCorVarDefined,
      boolean parentPropagatesNullValues) {
    // A SEMI/ANTI join decorrelates its inputs directly, because the
    // correlate variables can only be propagated from the left side,
    // which is not supported yet.
    if (!rel.getJoinType().projectsRight()) {
      return decorrelateRel((RelNode) rel, isCorVarDefined, parentPropagatesNullValues);
    }

    final RelNode oldLeft = rel.getInput(0);
    final RelNode oldRight = rel.getInput(1);

    // Both inputs are visited before either result is checked.
    final Frame leftFrame =
        getInvoke(oldLeft, isCorVarDefined, rel, parentPropagatesNullValues);
    final Frame rightFrame =
        getInvoke(oldRight, isCorVarDefined, rel, parentPropagatesNullValues);
    if (leftFrame == null || rightFrame == null) {
      // An input that was not rewritten means this join is not rewritten.
      return null;
    }

    // Step 1: rebuild the join over the new inputs with a rewritten condition.
    relBuilder.push(leftFrame.r).push(rightFrame.r);
    final RexNode newCondition =
        decorrelateExpr(castNonNull(currentRel), map, cm, rel.getCondition());
    final RelNode newJoin =
        relBuilder.join(rel.getJoinType(), newCondition, ImmutableSet.of()).build();

    // Step 2: map old output positions to new ones.
    final int oldLeftFieldCount = oldLeft.getRowType().getFieldCount();
    final int oldRightFieldCount = oldRight.getRowType().getFieldCount();
    final int newLeftFieldCount = leftFrame.r.getRowType().getFieldCount();
    //noinspection AssertWithSideEffects
    assert rel.getRowType().getFieldCount()
        == oldLeftFieldCount + oldRightFieldCount;

    // Left positions carry over unchanged; right positions move past the
    // fields of the new left input.
    final Map<Integer, Integer> mapOldToNewOutputs = new HashMap<>();
    mapOldToNewOutputs.putAll(leftFrame.oldToNewOutputs);
    for (int oldRightPos = 0; oldRightPos < oldRightFieldCount; oldRightPos++) {
      final int newRightPos =
          requireNonNull(rightFrame.oldToNewOutputs.get(oldRightPos));
      mapOldToNewOutputs.put(oldRightPos + oldLeftFieldCount,
          newRightPos + newLeftFieldCount);
    }

    // Step 3: merge the correlated variables output by both sides, shifting
    // those of the right side in the same way.
    final NavigableMap<CorDef, Integer> corDefOutputs =
        new TreeMap<>(leftFrame.corDefOutputs);
    rightFrame.corDefOutputs.forEach((corDef, rightPos) ->
        corDefOutputs.put(corDef, rightPos + newLeftFieldCount));

    return register(rel, newJoin, mapOldToNewOutputs, corDefOutputs);
  }

  /**
   * Translates an input reference over the old inputs of a relational
   * expression into the equivalent reference over its rewritten inputs.
   *
   * @param currentRel Relational expression whose inputs the reference is
   *     relative to
   * @param map Mapping from each old input to the frame that describes its
   *     rewritten form
   * @param oldInputRef Reference into the concatenated fields of the old inputs
   * @return reference into the concatenated fields of the new inputs, typed
   *     like the field it points at
   */
  private static RexInputRef getNewForOldInputRef(RelNode currentRel,
      Map<RelNode, Frame> map, RexInputRef oldInputRef) {
    requireNonNull(currentRel, "currentRel");

    // Walk the inputs left to right until the one that owns the referenced
    // field is found. Along the way, "localOrdinal" becomes relative to the
    // current input and "newOffset" accumulates the widths of the rewritten
    // inputs that were skipped.
    int localOrdinal = oldInputRef.getIndex();
    int newOffset = 0;
    RelNode owner = null;
    for (RelNode candidate : currentRel.getInputs()) {
      final int width = candidate.getRowType().getFieldCount();
      if (localOrdinal < width) {
        owner = candidate;
        break;
      }
      final Frame skipped =
          requireNonNull(map.get(candidate),
              () -> "map.get(oldInput0) for " + candidate);
      newOffset += skipped.r.getRowType().getFieldCount();
      localOrdinal -= width;
    }

    requireNonNull(owner, "oldInput");
    final Frame frame = requireNonNull(map.get(owner));

    // An empty mapping leaves the local position as it was; otherwise the
    // frame says where the field went.
    final int newLocalOrdinal;
    if (frame.oldToNewOutputs.isEmpty()) {
      newLocalOrdinal = localOrdinal;
    } else {
      newLocalOrdinal = requireNonNull(frame.oldToNewOutputs.get(localOrdinal));
    }

    final RelDataType newType =
        frame.r.getRowType().getFieldList().get(newLocalOrdinal).getType();
    return new RexInputRef(newOffset + newLocalOrdinal, newType);
  }

  /**
   * Pulls project above the join from its RHS input. Enforces nullability
   * for join output.
   *
   * <p>The new Project emits every field of the join's left input, followed by
   * the expressions of the original project with correlations removed.
   *
   * @param join          Join
   * @param project       Original project as the right-hand input of the join
   * @param nullIndicatorPos Position of null indicator
   * @return the subtree with the new Project at the root
   */
  private RelNode projectJoinOutputWithNullability(
      Join join,
      Project project,
      int nullIndicatorPos) {
    // The null indicator is a reference to a join output field, forced to be
    // nullable.
    final RelDataTypeFactory typeFactory = join.getCluster().getTypeFactory();
    final RelDataType indicatorType =
        typeFactory.createTypeWithNullability(
            join.getRowType().getFieldList().get(nullIndicatorPos).getType(),
            true);
    final RexInputRef nullIndicator =
        new RexInputRef(nullIndicatorPos, indicatorType);

    final PairList<RexNode, String> newProjExprs = PairList.of();

    // First, everything from the LHS.
    final List<RelDataTypeField> leftInputFields =
        join.getLeft().getRowType().getFieldList();
    final int leftFieldCount = leftInputFields.size();
    for (int ordinal = 0; ordinal < leftFieldCount; ordinal++) {
      RexInputRef.add2(newProjExprs, ordinal, leftInputFields);
    }

    // Then the original projections. Record whether they now come out of the
    // null-generating side of the join, so that their types become nullable.
    final boolean projectPulledAboveLeftCorrelator =
        join.getJoinType().generatesNullsOnRight();
    for (Pair<RexNode, String> namedProject : project.getNamedProjects()) {
      newProjExprs.add(
          removeCorrelationExpr(namedProject.left,
              projectPulledAboveLeftCorrelator, nullIndicator),
          namedProject.right);
    }

    relBuilder.push(join);
    relBuilder.projectNamed(newProjExprs.leftList(), newProjExprs.rightList(), true);
    return relBuilder.build();
  }

  /**
   * Pulls a {@link Project} above a {@link Correlate} from its RHS input.
   * Enforces nullability for join output.
   *
   * <p>The new Project emits every field of the Correlate's left input,
   * followed by the expressions of the original project with correlations
   * removed.
   *
   * @param correlate  Correlate
   * @param project the original project as the RHS input of the join
   * @param isCount Positions which are calls to the <code>COUNT</code>
   *                aggregation function
   * @return the subtree with the new Project at the root
   */
  private RelNode aggregateCorrelatorOutput(
      Correlate correlate,
      Project project,
      Set<Integer> isCount) {
    final PairList<RexNode, String> newProjects = PairList.of();

    // First, everything from the LHS.
    final List<RelDataTypeField> leftInputFields =
        correlate.getLeft().getRowType().getFieldList();
    final int leftFieldCount = leftInputFields.size();
    for (int ordinal = 0; ordinal < leftFieldCount; ordinal++) {
      RexInputRef.add2(newProjects, ordinal, leftInputFields);
    }

    // Then the original projections. Record whether they now come out of the
    // null-generating side of the join, so that their types become nullable.
    final boolean projectPulledAboveLeftCorrelator =
        correlate.getJoinType().generatesNullsOnRight();
    for (Pair<RexNode, String> namedProject : project.getNamedProjects()) {
      newProjects.add(
          removeCorrelationExpr(namedProject.left,
              projectPulledAboveLeftCorrelator, isCount),
          namedProject.right);
    }

    relBuilder.push(correlate);
    relBuilder.projectNamed(newProjects.leftList(), newProjects.rightList(), true);
    return relBuilder.build();
  }

  /**
   * Checks whether the correlations in a project and a filter are related to
   * the correlated variables provided by a Correlate.
   *
   * @param correlate    Correlate
   * @param project   The original Project as the RHS input of the join, or
   *                  null to skip the project check
   * @param filter    Filter, or null to skip the filter check
   * @param correlatedJoinKeys Correlated join keys; must not be null if
   *                  {@code filter} is not null
   * @return true if filter and project only reference correlated variables
   *     provided by {@code correlate}
   */
  private boolean checkCorVars(
      Correlate correlate,
      @Nullable Project project,
      @Nullable Filter filter,
      @Nullable List<RexFieldAccess> correlatedJoinKeys) {
    if (filter != null) {
      requireNonNull(correlatedJoinKeys, "correlatedJoinKeys");

      // Every correlated reference in the filter condition must be one of the
      // join keys (as a field access): strike the keys off and see whether
      // anything is left.
      final Set<CorRef> filterCorRefs =
          Sets.newHashSet(cm.mapRefRelToCorRef.get(filter));
      for (RexFieldAccess joinKey : correlatedJoinKeys) {
        filterCorRefs.remove(cm.mapFieldAccessToCorRef.get(joinKey));
      }
      if (!filterCorRefs.isEmpty()) {
        return false;
      }

      // The set is empty at this point; refill it and check that each
      // reference really comes from this Correlate.
      filterCorRefs.addAll(cm.mapRefRelToCorRef.get(filter));
      for (CorRef filterCorRef : filterCorRefs) {
        if (cm.mapCorToCorRel.get(filterCorRef.corr) != correlate) {
          return false;
        }
      }
    }

    // A project without correlated references has nothing further to check.
    if (project == null || !cm.mapRefRelToCorRef.containsKey(project)) {
      return true;
    }

    // Correlated references in the project must also be provided by this
    // Correlate; they will be projected out of its LHS.
    for (CorRef projectCorRef : cm.mapRefRelToCorRef.get(project)) {
      if (cm.mapCorToCorRel.get(projectCorRef.corr) != correlate) {
        return false;
      }
    }
    return true;
  }

  /**
   * Removes the entry that maps a Correlate's correlation id to that
   * Correlate from the correlation map. The entry is removed only if the id
   * is currently mapped to this very Correlate.
   *
   * @param correlate Correlate
   */
  private void removeCorVarFromTree(Correlate correlate) {
    final CorrelationId correlationId = correlate.getCorrelationId();
    cm.mapCorToCorRel.remove(correlationId, correlate);
  }

  /**
   * Projects all {@code input} output fields plus the additional expressions.
   *
   * <p>The input fields come first, in their original order and under their
   * original names; the additional expressions follow.
   *
   * @param input        Input relational expression
   * @param additionalExprs Additional expressions and names
   * @return the new Project
   */
  private RelNode createProjectWithAdditionalExprs(
      RelNode input,
      PairList<RexNode, String> additionalExprs) {
    final PairList<RexNode, String> projects = PairList.of();

    // One input reference per field of the input.
    int ordinal = 0;
    for (RelDataTypeField field : input.getRowType().getFieldList()) {
      final RexNode ref =
          relBuilder.getRexBuilder().makeInputRef(field.getType(), ordinal);
      projects.add(ref, field.getName());
      ordinal++;
    }
    projects.addAll(additionalExprs);

    relBuilder.push(input);
    relBuilder.projectNamed(projects.leftList(), projects.rightList(), true);
    return relBuilder.build();
  }

  /** Returns an immutable map with the identity [0: 0, .., count-1: count-1].
   *
   * @param count Number of entries; zero or less gives an empty map
   * @return immutable map in which each of 0 .. count-1 maps to itself
   */
  static Map<Integer, Integer> identityMap(int count) {
    final ImmutableMap.Builder<Integer, Integer> identity = ImmutableMap.builder();
    int ordinal = 0;
    while (ordinal < count) {
      identity.put(ordinal, ordinal);
      ordinal++;
    }
    return identity.build();
  }

  /** Registers a relational expression and the relational expression it became
   * after decorrelation.
   *
   * <p>The hints of the old expression are copied onto the new one before the
   * frame is created and stored under the old expression.
   *
   * @param rel Old relational expression
   * @param newRel Relational expression that replaces it
   * @param oldToNewOutputs Mapping from old to new output positions
   * @param corDefOutputs Output positions of the correlated variables
   * @return the frame that was registered
   */
  Frame register(RelNode rel, RelNode newRel,
      Map<Integer, Integer> oldToNewOutputs,
      NavigableMap<CorDef, Integer> corDefOutputs) {
    final RelNode newRelWithHints = RelOptUtil.copyRelHints(rel, newRel);
    final Frame registered =
        new Frame(rel, newRelWithHints, corDefOutputs, oldToNewOutputs);
    map.put(rel, registered);
    return registered;
  }

  /** Checks that every value in a collection is strictly less than a limit.
   *
   * @param integers Values to check
   * @param limit Exclusive upper bound
   * @param ret Receives the outcome: {@code fail} is called for the first
   *     value that reaches the limit, {@code succeed} if there is none
   * @return whatever the call on {@code ret} returns
   */
  static boolean allLessThan(Collection<Integer> integers, int limit,
      Litmus ret) {
    for (Integer boxed : integers) {
      final int value = boxed;
      if (limit <= value) {
        return ret.fail("out of range; value: {}, limit: {}", value, limit);
      }
    }
    return ret.succeed();
  }

  /** Returns the stripped form of a relational expression if it is a
   * {@link HepRelVertex}; any other expression is returned as is. */
  private static RelNode stripHep(RelNode rel) {
    if (rel instanceof HepRelVertex) {
      return rel.stripped();
    }
    return rel;
  }

  //~ Inner Classes ----------------------------------------------------------

  /** Shuttle that decorrelates.
   *
   * <p>Rewrites the expressions of a relational expression so that they
   * refer to its rewritten inputs: correlated field accesses become
   * references to the input field that outputs the correlated variable, and
   * input references are moved to their new positions. */
  private static class DecorrelateRexShuttle extends RexShuttle {
    private final RelNode currentRel;
    private final Map<RelNode, Frame> map;
    private final CorelMap cm;

    private DecorrelateRexShuttle(RelNode currentRel,
        Map<RelNode, Frame> map, CorelMap cm) {
      this.currentRel = requireNonNull(currentRel, "currentRel");
      this.map = requireNonNull(map, "map");
      this.cm = requireNonNull(cm, "cm");
    }

    /** Replaces a correlated field access by a reference to the first input
     * field that outputs the variable; returns the access unchanged if no
     * input outputs it. */
    @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
      // Number of fields output by the inputs visited so far.
      int offset = 0;
      for (RelNode input : currentRel.getInputs()) {
        final Frame frame = map.get(input);
        if (frame == null) {
          // This input is not rewritten; skip over its original fields.
          offset += input.getRowType().getFieldCount();
          continue;
        }

        // Try to find in this input rel the position of the corVar.
        final CorRef corRef = cm.mapFieldAccessToCorRef.get(fieldAccess);
        final @Nullable Integer newInputPos =
            corRef == null ? null : frame.corDefOutputs.get(corRef.def());
        if (newInputPos != null) {
          // This input does produce the corVar referenced.
          final RelDataType newType =
              frame.r.getRowType().getFieldList().get(newInputPos).getType();
          return new RexInputRef(newInputPos + offset, newType);
        }

        // This input does not produce the corVar needed; skip over the fields
        // of its rewritten form.
        offset += frame.r.getRowType().getFieldCount();
      }
      return fieldAccess;
    }

    /** Moves an input reference to its position over the rewritten inputs. */
    @Override public RexNode visitInputRef(RexInputRef inputRef) {
      final RexInputRef moved = getNewForOldInputRef(currentRel, map, inputRef);
      // Re-use the old object where possible, to prevent needless expr
      // cloning.
      final boolean unchanged = moved.getIndex() == inputRef.getIndex()
          && moved.getType() == inputRef.getType();
      return unchanged ? inputRef : moved;
    }
  }

  /** Shuttle that removes correlations.
   *
   * <p>Used on expressions that have been pulled above a Correlate:
   * correlated field accesses become input references, and, when a null
   * indicator is supplied and the expression now comes from the
   * null-generating side, results are wrapped in a CASE on that indicator. */
  private class RemoveCorrelationRexShuttle extends RexShuttle {
    final RexBuilder rexBuilder;
    final RelDataTypeFactory typeFactory;
    final boolean projectPulledAboveLeftCorrelator;
    final @Nullable RexInputRef nullIndicator;
    final ImmutableSet<Integer> isCount;

    RemoveCorrelationRexShuttle(
        RexBuilder rexBuilder,
        boolean projectPulledAboveLeftCorrelator,
        @Nullable RexInputRef nullIndicator,
        Set<Integer> isCount) {
      this.rexBuilder = rexBuilder;
      this.typeFactory = rexBuilder.getTypeFactory();
      this.projectPulledAboveLeftCorrelator =
          projectPulledAboveLeftCorrelator;
      this.nullIndicator = nullIndicator; // may be null
      this.isCount = ImmutableSet.copyOf(isCount);
    }

    /** Builds
     * {@code CASE WHEN nullInputRef IS NULL THEN x ELSE CAST(rexNode) END},
     * where {@code x} is {@code lit} cast to the type of {@code rexNode}, or
     * a null literal of that type if {@code lit} is null. */
    private RexNode createCaseExpression(
        RexInputRef nullInputRef,
        @Nullable RexLiteral lit,
        RexNode rexNode) {
      // Construct a CASE expression to handle the null indicator.
      //
      // This also covers the case where a left correlated sub-query
      // projects fields from outer relation. Since LOJ cannot produce
      // nulls on the LHS, the projection now need to make a nullable LHS
      // reference using a nullability indicator. If this indicator
      // is null, it means the sub-query does not produce any value. As a
      // result, any RHS ref by this sub-query needs to produce null value.
      final RelDataType valueType = rexNode.getType();

      // WHEN indicator IS NULL
      final RexInputRef nullableIndicator =
          new RexInputRef(nullInputRef.getIndex(),
              typeFactory.createTypeWithNullability(nullInputRef.getType(), true));
      final RexNode whenClause =
          rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, nullableIndicator);

      // THEN CAST(NULL AS newInputTypeNullable)
      final RexNode thenClause;
      if (lit == null) {
        thenClause = rexBuilder.makeNullLiteral(valueType);
      } else {
        thenClause = rexBuilder.makeCast(valueType, lit);
      }

      // ELSE cast (newInput AS newInputTypeNullable) END
      final RexNode elseClause =
          rexBuilder.makeCast(
              typeFactory.createTypeWithNullability(valueType, true),
              rexNode);

      final RexNode[] caseOperands = {whenClause, thenClause, elseClause};
      return rexBuilder.makeCall(SqlStdOperatorTable.CASE, caseOperands);
    }

    @Override public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
      if (!cm.mapFieldAccessToCorRef.containsKey(fieldAccess)) {
        // Not a corVar; leave it alone.
        return fieldAccess;
      }

      // It is a corVar; change it to be an input ref. The corVar offset
      // should point to the leftInput of currentRel, which is the Correlate.
      final CorRef corVar = cm.mapFieldAccessToCorRef.get(fieldAccess);
      final RexNode inputRef =
          new RexInputRef(corVar.field, fieldAccess.getType());

      if (!projectPulledAboveLeftCorrelator || nullIndicator == null) {
        return inputRef;
      }
      // Need to enforce nullability by applying an additional cast operator
      // over the transformed expression.
      return createCaseExpression(nullIndicator, null, inputRef);
    }

    @Override public RexNode visitInputRef(RexInputRef inputRef) {
      if (!(currentRel instanceof Correlate)) {
        return inputRef;
      }

      // If this rel references a corVar and now needs to be rewritten, it
      // must have been pulled above the Correlate. Shift the input ref to
      // account for the LHS of the Correlate.
      final Correlate correlate = (Correlate) currentRel;
      final int leftInputFieldCount =
          correlate.getLeft().getRowType().getFieldCount();

      final RelDataType newType;
      if (projectPulledAboveLeftCorrelator) {
        newType =
            typeFactory.createTypeWithNullability(inputRef.getType(), true);
      } else {
        newType = inputRef.getType();
      }

      final int pos = inputRef.getIndex();
      final RexInputRef newInputRef =
          new RexInputRef(leftInputFieldCount + pos, newType);
      if (!isCount.contains(pos)) {
        return newInputRef;
      }
      // A COUNT position yields zero, rather than null, when the reference
      // itself is null.
      return createCaseExpression(
          newInputRef,
          rexBuilder.makeExactLiteral(BigDecimal.ZERO),
          newInputRef);
    }

    @Override public RexNode visitLiteral(RexLiteral literal) {
      // Use nullIndicator to decide whether to project null.
      if (!projectPulledAboveLeftCorrelator || nullIndicator == null) {
        return literal;
      }
      // Do nothing if the literal is null or symbol.
      if (RexUtil.isNull(literal) || RexUtil.isSymbolLiteral(literal)) {
        return literal;
      }
      return createCaseExpression(nullIndicator, null, literal);
    }

    @Override public RexNode visitCall(final RexCall call) {
      final boolean[] update = {false};
      final List<RexNode> clonedOperands = visitList(call.operands, update);

      RexNode newCall = call;
      if (update[0]) {
        final SqlOperator operator = call.getOperator();

        // A cast function with less than 2 operands has its return type built
        // into the operator definition, with no type inference rules.
        final boolean isSpecialCast = operator instanceof SqlFunction
            && ((SqlFunction) operator).getKind() == SqlKind.CAST
            && call.operands.size() < 2;

        final RelDataType newType;
        if (isSpecialCast) {
          // Use the current return type when creating the new call.

          // TODO: Comments in RexShuttle.visitCall() mention other
          // types in this category. Need to resolve those together
          // and preferably in the base class RexShuttle.
          newType = call.getType();
        } else {
          // TODO: ideally this only needs to be called if the result
          // type will also change. However, since that requires
          // support from type inference rules to tell whether a rule
          // decides return type based on input types, for now all
          // operators will be recreated with new type if any operand
          // changed, unless the operator has "built-in" type.
          newType = rexBuilder.deriveReturnType(operator, clonedOperands);
        }

        newCall =
            rexBuilder.makeCall(
                call.getParserPosition(),
                newType,
                operator,
                clonedOperands);
      }

      if (!projectPulledAboveLeftCorrelator || nullIndicator == null) {
        return newCall;
      }
      return createCaseExpression(nullIndicator, null, newCall);
    }
  }

  /**
   * Rule to remove an Aggregate with SINGLE_VALUE. For cases like:
   *
   * <pre>{@code
   * Aggregate(SINGLE_VALUE)
   *   Project(single expression)
   *     Aggregate
   * }</pre>
   *
   * <p>For instance, the following subtree from TPCH query 17:
   *
   * <pre>{@code
   * LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])
   *   LogicalProject(EXPR$0=[*(0.2:DECIMAL(2, 1), $0)])
   *     LogicalAggregate(group=[{}], agg#0=[AVG($0)])
   *       LogicalProject(L_QUANTITY=[$4])
   *         LogicalFilter(condition=[=($1, $cor0.P_PARTKEY)])
   *           LogicalTableScan(table=[[TPCH_01, LINEITEM]])
   * }</pre>
   *
   * <p>will be converted into:
   *
   * <pre>{@code
   * LogicalProject($f0=[*(0.2:DECIMAL(2, 1), $0)])
   *   LogicalAggregate(group=[{}], agg#0=[AVG($0)])
   *     LogicalProject(L_QUANTITY=[$4])
   *       LogicalFilter(condition=[=($1, $cor0.P_PARTKEY)])
   *         LogicalTableScan(table=[[TPCH_01, LINEITEM]])
   * }</pre>
   */
  public static final class RemoveSingleAggregateRule
      extends RelRule<RemoveSingleAggregateRule.RemoveSingleAggregateRuleConfig> {

    /** Default configuration: matches an Aggregate on a Project on an
     * Aggregate. */
    static final RemoveSingleAggregateRuleConfig DEFAULT =
        ImmutableRemoveSingleAggregateRuleConfig.builder()
            .withOperandSupplier(top ->
                top.operand(Aggregate.class).oneInput(middle ->
                    middle.operand(Project.class).oneInput(bottom ->
                        bottom.operand(Aggregate.class).anyInputs())))
            .build();

    /** Creates a RemoveSingleAggregateRule. */
    RemoveSingleAggregateRule(RemoveSingleAggregateRuleConfig config) {
      super(config);
    }

    @Override public void onMatch(RelOptRuleCall call) {
      final Aggregate singleAggregate = call.rel(0);
      final Project project = call.rel(1);
      final Aggregate aggregate = call.rel(2);

      // The top aggregate must be a lone SINGLE_VALUE over the entire input.
      final List<AggregateCall> topCalls = singleAggregate.getAggCallList();
      final boolean isGlobalSingleValue =
          singleAggregate.getGroupSet().isEmpty()
              && topCalls.size() == 1
              && topCalls.get(0).getAggregation()
                  instanceof SqlSingleValueAggFunction;
      if (!isGlobalSingleValue) {
        return;
      }

      // The project must project just one expression, i.e. a scalar
      // sub-query.
      if (project.getProjects().size() != 1) {
        return;
      }

      // The input to the project must be an aggregate on the entire input.
      if (!aggregate.getGroupSet().isEmpty()) {
        return;
      }

      // Put the project straight on the lower aggregate, converting to make
      // sure we keep the same type after removing the SINGLE_VALUE Aggregate.
      final RelBuilder builder = call.builder();
      final RelNode replacement =
          builder.push(aggregate)
              .project(project.getAliasedProjects(builder))
              .convert(singleAggregate.getRowType(), false)
              .build();
      call.transformTo(replacement);
    }

    /** Rule configuration. */
    @Value.Immutable(singleton = false)
    public interface RemoveSingleAggregateRuleConfig extends RelRule.Config {
      @Override default RemoveSingleAggregateRule toRule() {
        return new RemoveSingleAggregateRule(this);
      }
    }
  }

  /**
   * Planner rule that removes correlations for scalar projects.
   *
   * <p>It matches a {@link Correlate} whose right input is an
   * {@link Aggregate} on top of a {@code Project}. The rule only fires for a
   * LEFT correlation where the aggregate has no group keys and exactly one
   * {@code SqlSingleValueAggFunction} call, and the project has exactly one
   * expression. The Correlate is then replaced by a Project on top of a
   * {@link Join}.
   */
  public static final class RemoveCorrelationForScalarProjectRule
      extends RelRule<RemoveCorrelationForScalarProjectRule
      .RemoveCorrelationForScalarProjectRuleConfig> {

    /** Default configuration; operands are
     * Correlate(any, Aggregate(Project(any))). */
    static final RemoveCorrelationForScalarProjectRuleConfig DEFAULT =
        ImmutableRemoveCorrelationForScalarProjectRuleConfig.builder()
            .withOperandSupplier(b0 ->
                b0.operand(Correlate.class).inputs(
                    b1 -> b1.operand(RelNode.class).anyInputs(),
                    b2 -> b2.operand(Aggregate.class).oneInput(b3 ->
                        b3.operand(Project.class).oneInput(b4 ->
                            b4.operand(RelNode.class).anyInputs()))))
            .build();

    /** Creates a RemoveCorrelationForScalarProjectRule. */
    RemoveCorrelationForScalarProjectRule(RemoveCorrelationForScalarProjectRuleConfig config) {
      super(config);
    }

    /**
     * Rewrites the matched Correlate into a Project over a Join, or returns
     * without transforming if any precondition fails.
     *
     * @param call Rule call; operands 0..4 are the Correlate, its left
     *             input, the Aggregate, the Project and the Project's input
     */
    @Override public void onMatch(RelOptRuleCall call) {
      final RelDecorrelator d = call.getPlanner().getDecorrelator();
      final Correlate correlate = call.rel(0);
      final RelNode left = call.rel(1);
      final Aggregate aggregate = call.rel(2);
      final Project project = call.rel(3);
      RelNode right = call.rel(4);
      final RelOptCluster cluster = correlate.getCluster();

      d.setCurrent(call.getPlanner().getRoot(), correlate);

      // Check for this pattern.
      // The pattern matching could be simplified if rules can be applied
      // during decorrelation.
      //
      // Correlate(left correlation, condition = true)
      //   leftInput
      //   Aggregate (groupby (0) single_value())
      //     Project-A (may reference corVar)
      //       rightInput
      final JoinRelType joinType = correlate.getJoinType();

      // Correlate never carries a join condition, so the condition of the
      // join that replaces it starts out as TRUE. Only the join type needs
      // to be checked here; the filter branch below may replace joinCond.
      RexNode joinCond = d.relBuilder.literal(true);
      if (joinType != JoinRelType.LEFT) {
        return;
      }

      // check that the agg is of the following type:
      // doing a single_value() on the entire input
      if (!aggregate.getGroupSet().isEmpty()
          || (aggregate.getAggCallList().size() != 1)
          || !(aggregate.getAggCallList().get(0).getAggregation()
          instanceof SqlSingleValueAggFunction)) {
        return;
      }

      // check this project only projects one expression, i.e. scalar
      // sub-queries.
      if (project.getProjects().size() != 1) {
        return;
      }

      int nullIndicatorPos;

      if ((right instanceof Filter)
          && d.cm.mapRefRelToCorRef.containsKey(right)) {
        // rightInput has this shape:
        //
        //       Filter (references corVar)
        //         filterInput

        // If rightInput is a filter and contains correlated
        // reference, make sure the correlated keys in the filter
        // condition forms a unique key of the RHS.

        Filter filter = (Filter) right;
        right = filter.getInput();

        assert right instanceof HepRelVertex;
        right = right.stripped();

        // check filter input contains no correlation
        if (!RelOptUtil.getVariablesUsed(right).isEmpty()) {
          return;
        }

        // extract the correlation out of the filter

        // First breaking up the filter conditions into equality
        // comparisons between rightJoinKeys (from the original
        // filterInput) and correlatedJoinKeys. correlatedJoinKeys
        // can be expressions, while rightJoinKeys need to be input
        // refs. These comparisons are AND'ed together.
        List<RexNode> tmpRightJoinKeys = new ArrayList<>();
        List<RexNode> correlatedJoinKeys = new ArrayList<>();
        RelOptUtil.splitCorrelatedFilterCondition(
            filter,
            tmpRightJoinKeys,
            correlatedJoinKeys,
            false);

        // narrow the right-hand keys to input refs
        final List<RexInputRef> rightJoinKeys = new ArrayList<>();
        for (RexNode key : tmpRightJoinKeys) {
          assert key instanceof RexInputRef;
          rightJoinKeys.add((RexInputRef) key);
        }

        // without any equality key there is nothing that could form a
        // unique key of the filterInput
        if (rightJoinKeys.isEmpty()) {
          return;
        }

        // check that the columns referenced in rightJoinKeys form a
        // unique key of the filterInput.
        // The join filters out the nulls.  So, it's ok if there are
        // nulls in the join keys.
        final RelMetadataQuery mq = call.getMetadataQuery();
        if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, right,
            rightJoinKeys)) {
          SQL2REL_LOGGER.debug("{} are not unique keys for {}",
              rightJoinKeys, right);
          return;
        }

        RexUtil.FieldAccessFinder visitor =
            new RexUtil.FieldAccessFinder();
        RexUtil.apply(visitor, correlatedJoinKeys, null);
        List<RexFieldAccess> correlatedKeyList =
            visitor.getFieldAccessList();

        if (!d.checkCorVars(correlate, project, filter, correlatedKeyList)) {
          return;
        }

        // Change the plan to this structure.
        // Note that the Aggregate is removed.
        //
        // Project-A' (replace corVar to input ref from the Join)
        //   Join (replace corVar to input ref from leftInput)
        //     leftInput
        //     rightInput (previously filterInput)

        // Change the filter condition into a join condition
        joinCond =
            d.removeCorrelationExpr(filter.getCondition(), false);

        // the first right-hand join key, shifted past the left fields,
        // serves as the null indicator
        nullIndicatorPos =
            left.getRowType().getFieldCount()
                + rightJoinKeys.get(0).getIndex();
      } else if (d.cm.mapRefRelToCorRef.containsKey(project)) {
        // check project input contains no correlation
        if (!RelOptUtil.getVariablesUsed(right).isEmpty()) {
          return;
        }

        if (!d.checkCorVars(correlate, project, null, null)) {
          return;
        }

        // Change the plan to this structure.
        //
        // Project-A' (replace corVar to input ref from Join)
        //   Join (left, condition = true)
        //     leftInput
        //     Aggregate(groupby(0), single_value(0), s_v(1)....)
        //       Project-B (everything from input plus literal true)
        //         projectInput

        // make the new Project to provide a null indicator
        right =
            d.createProjectWithAdditionalExprs(right,
                PairList.of(d.relBuilder.literal(true), "nullIndicator"));

        // make the new aggRel
        right =
            RelOptUtil.createSingleValueAggRel(cluster, right);

        // The last field:
        //     single_value(true)
        // is the nullIndicator
        nullIndicatorPos =
            left.getRowType().getFieldCount()
                + right.getRowType().getFieldCount() - 1;
      } else {
        return;
      }

      // make the new join rel
      final Join join = (Join) d.relBuilder.push(left).push(right)
          .join(joinType, joinCond).build();

      RelNode newProject =
          d.projectJoinOutputWithNullability(join, project, nullIndicatorPos);

      call.transformTo(newProject);

      d.removeCorVarFromTree(correlate);
    }

    /** Rule configuration. */
    @Value.Immutable(singleton = false)
    public interface RemoveCorrelationForScalarProjectRuleConfig extends RelRule.Config {
      @Override default RemoveCorrelationForScalarProjectRule toRule() {
        return new RemoveCorrelationForScalarProjectRule(this);
      }
    }
  }

  /** Planner rule that removes correlations for scalar aggregates. */
  public static final class RemoveCorrelationForScalarAggregateRule
      extends RelRule<RemoveCorrelationForScalarAggregateRule
      .RemoveCorrelationForScalarAggregateRuleConfig> {

    static final RemoveCorrelationForScalarAggregateRuleConfig DEFAULT =
        ImmutableRemoveCorrelationForScalarAggregateRuleConfig.builder()
            .withOperandSupplier(b0 ->
                b0.operand(Correlate.class).inputs(
                    b1 -> b1.operand(RelNode.class).anyInputs(),
                    b2 -> b2.operand(Project.class).oneInput(b3 ->
                        b3.operand(Aggregate.class)
                            .predicate(Aggregate::isSimple).oneInput(b4 ->
                                b4.operand(Project.class).oneInput(b5 ->
                                    b5.operand(RelNode.class).anyInputs())))))
            .build();

    /** Creates a RemoveCorrelationForScalarAggregateRule. */
    RemoveCorrelationForScalarAggregateRule(RemoveCorrelationForScalarAggregateRuleConfig config) {
      super(config);
    }

    @Override public void onMatch(RelOptRuleCall call) {
      final RelDecorrelator d = call.getPlanner().getDecorrelator();
      final Correlate correlate = call.rel(0);
      final RelNode left = call.rel(1);
      final Project aggOutputProject = call.rel(2);
      final Aggregate aggregate = call.rel(3);
      final Project aggInputProject = call.rel(4);
      RelNode right = call.rel(5);
      final RelBuilder builder = call.builder();
      final RexBuilder rexBuilder = builder.getRexBuilder();
      final RelOptCluster cluster = correlate.getCluster();

      d.setCurrent(call.getPlanner().getRoot(), correlate);

      // check for this pattern
      // The pattern matching could be simplified if rules can be applied
      // during decorrelation,
      //
      // CorrelateRel(left correlation, condition = true)
      //   leftInput
      //   Project-A (a RexNode)
      //     Aggregate (groupby (0), agg0(), agg1()...)
      //       Project-B (references coVar)
      //         rightInput

      // check aggOutputProject projects only one expression
      final List<RexNode> aggOutputProjects = aggOutputProject.getProjects();
      if (aggOutputProjects.size() != 1) {
        return;
      }

      final JoinRelType joinType = correlate.getJoinType();
      // corRel.getCondition was here, however Correlate was updated so it
      // never includes a join condition. The code was not modified for brevity.
      RexNode joinCond = rexBuilder.makeLiteral(true);
      if ((joinType != JoinRelType.LEFT)
          || (joinCond != rexBuilder.makeLiteral(true))) {
        return;
      }

      // check that the agg is on the entire input
      if (!aggregate.getGroupSet().isEmpty()) {
        return;
      }

      final List<RexNode> aggInputProjects = aggInputProject.getProjects();

      final List<AggregateCall> aggCalls = aggregate.getAggCallList();
      final Set<Integer> isCountStar = new HashSet<>();

      // mark if agg produces count(*) which needs to reference the
      // nullIndicator after the transformation.
      int k = -1;
      for (AggregateCall aggCall : aggCalls) {
        ++k;
        if (aggCall.getAggregation() instanceof SqlCountAggFunction
            && aggCall.getArgList().isEmpty()) {
          isCountStar.add(k);
        }
      }

      if ((right instanceof Filter)
          && d.cm.mapRefRelToCorRef.containsKey(right)) {
        // rightInput has this shape:
        //
        //       Filter (references corVar)
        //         filterInput
        Filter filter = (Filter) right;
        right = filter.getInput();

        assert right instanceof HepRelVertex;
        right = right.stripped();

        // check filter input contains no correlation
        if (!RelOptUtil.getVariablesUsed(right).isEmpty()) {
          return;
        }

        // check filter condition type First extract the correlation out
        // of the filter

        // First breaking up the filter conditions into equality
        // comparisons between rightJoinKeys(from the original
        // filterInput) and correlatedJoinKeys. correlatedJoinKeys
        // can only be RexFieldAccess, while rightJoinKeys can be
        // expressions. These comparisons are AND'ed together.
        List<RexNode> rightJoinKeys = new ArrayList<>();
        List<RexNode> tmpCorrelatedJoinKeys = new ArrayList<>();
        RelOptUtil.splitCorrelatedFilterCondition(
            filter,
            rightJoinKeys,
            tmpCorrelatedJoinKeys,
            true);

        // make sure the correlated reference forms a unique key check
        // that the columns referenced in these comparisons form an
        // unique key of the leftInput
        List<RexFieldAccess> correlatedJoinKeys = new ArrayList<>();
        List<RexInputRef> correlatedInputRefJoinKeys = new ArrayList<>();
        for (RexNode joinKey : tmpCorrelatedJoinKeys) {
          assert joinKey instanceof RexFieldAccess;
          correlatedJoinKeys.add((RexFieldAccess) joinKey);
          RexNode correlatedInputRef =
              d.removeCorrelationExpr(joinKey, false);
          assert correlatedInputRef instanceof RexInputRef;
          correlatedInputRefJoinKeys.add(
              (RexInputRef) correlatedInputRef);
        }

        // check that the columns referenced in rightJoinKeys form an
        // unique key of the filterInput
        if (correlatedInputRefJoinKeys.isEmpty()) {
          return;
        }

        // The join filters out the nulls.  So, it's ok if there are
        // nulls in the join keys.
        final RelMetadataQuery mq = call.getMetadataQuery();
        if (!RelMdUtil.areColumnsDefinitelyUniqueWhenNullsFiltered(mq, left,
            correlatedInputRefJoinKeys)) {
          SQL2REL_LOGGER.debug("{} are not unique keys for {}",
              correlatedJoinKeys, left);
          return;
        }

        // check corVar references are valid
        if (!d.checkCorVars(correlate, aggInputProject, filter,
            correlatedJoinKeys)) {
          return;
        }

        // Rewrite the above plan:
        //
        // Correlate(left correlation, condition = true)
        //   leftInput
        //   Project-A (a RexNode)
        //     Aggregate (groupby(0), agg0(),agg1()...)
        //       Project-B (may reference corVar)
        //         Filter (references corVar)
        //           rightInput (no correlated reference)
        //

        // to this plan:
        //
        // Project-A' (all gby keys + rewritten nullable ProjExpr)
        //   Aggregate (groupby(all left input refs)
        //                 agg0(rewritten expression),
        //                 agg1()...)
        //     Project-B' (rewritten original projected exprs)
        //       Join(replace corVar w/ input ref from leftInput)
        //         leftInput
        //         rightInput
        //

        // In the case where agg is count(*) or count($corVar), it is
        // changed to count(nullIndicator).
        // Note:  any non-nullable field from the RHS can be used as
        // the indicator however a "true" field is added to the
        // projection list from the RHS for simplicity to avoid
        // searching for non-null fields.
        //
        // Project-A' (all gby keys + rewritten nullable ProjExpr)
        //   Aggregate (groupby(all left input refs),
        //                 count(nullIndicator), other aggs...)
        //     Project-B' (all left input refs plus
        //                    the rewritten original projected exprs)
        //       Join(replace corVar to input ref from leftInput)
        //         leftInput
        //         Project (everything from rightInput plus
        //                     the nullIndicator "true")
        //           rightInput
        //

        // first change the filter condition into a join condition
        joinCond = d.removeCorrelationExpr(filter.getCondition(), false);
      } else if (d.cm.mapRefRelToCorRef.containsKey(aggInputProject)) {
        // check rightInput contains no correlation
        if (!RelOptUtil.getVariablesUsed(right).isEmpty()) {
          return;
        }

        // check corVar references are valid
        if (!d.checkCorVars(correlate, aggInputProject, null, null)) {
          return;
        }

        int nFields = left.getRowType().getFieldCount();
        ImmutableBitSet allCols = ImmutableBitSet.range(nFields);

        // leftInput contains unique keys
        // i.e. each row is distinct and can group by on all the left
        // fields
        final RelMetadataQuery mq = call.getMetadataQuery();
        if (!RelMdUtil.areColumnsDefinitelyUnique(mq, left, allCols)) {
          SQL2REL_LOGGER.debug("There are no unique keys for {}", left);
          return;
        }
        //
        // Rewrite the above plan:
        //
        // CorrelateRel(left correlation, condition = true)
        //   leftInput
        //   Project-A (a RexNode)
        //     Aggregate (groupby(0), agg0(), agg1()...)
        //       Project-B (references coVar)
        //         rightInput (no correlated reference)
        //

        // to this plan:
        //
        // Project-A' (all gby keys + rewritten nullable ProjExpr)
        //   Aggregate (groupby(all left input refs)
        //                 agg0(rewritten expression),
        //                 agg1()...)
        //     Project-B' (rewritten original projected exprs)
        //       Join (LOJ cond = true)
        //         leftInput
        //         rightInput
        //

        // In the case where agg is count($corVar), it is changed to
        // count(nullIndicator).
        // Note:  any non-nullable field from the RHS can be used as
        // the indicator however a "true" field is added to the
        // projection list from the RHS for simplicity to avoid
        // searching for non-null fields.
        //
        // Project-A' (all gby keys + rewritten nullable ProjExpr)
        //   Aggregate (groupby(all left input refs),
        //                 count(nullIndicator), other aggs...)
        //     Project-B' (all left input refs plus
        //                    the rewritten original projected exprs)
        //       Join (replace corVar to input ref from leftInput)
        //         leftInput
        //         Project (everything from rightInput plus
        //                     the nullIndicator "true")
        //           rightInput
      } else {
        return;
      }

      RelDataType leftInputFieldType = left.getRowType();
      int leftInputFieldCount = leftInputFieldType.getFieldCount();
      int joinOutputProjExprCount =
          leftInputFieldCount + aggInputProjects.size() + 1;

      right =
          d.createProjectWithAdditionalExprs(right,
              PairList.of(rexBuilder.makeLiteral(true), "nullIndicator"));

      Join join =
          (Join) d.relBuilder
              .push(left)
              .push(right)
              .join(joinType, joinCond)
              .build();

      // To the consumer of joinOutputProjRel, nullIndicator is located
      // at the end
      int nullIndicatorPos = join.getRowType().getFieldCount() - 1;

      RexInputRef nullIndicator =
          new RexInputRef(
              nullIndicatorPos,
              cluster.getTypeFactory().createTypeWithNullability(
                  join.getRowType().getFieldList()
                      .get(nullIndicatorPos).getType(),
                  true));

      // first project all group-by keys plus the transformed agg input
      List<RexNode> joinOutputProjects = new ArrayList<>();

      // LOJ Join preserves LHS types
      for (int i = 0; i < leftInputFieldCount; i++) {
        joinOutputProjects.add(
            rexBuilder.makeInputRef(
                leftInputFieldType.getFieldList().get(i).getType(), i));
      }

      for (RexNode aggInputProjExpr : aggInputProjects) {
        joinOutputProjects.add(
            d.removeCorrelationExpr(aggInputProjExpr,
                joinType.generatesNullsOnRight(),
                nullIndicator));
      }

      joinOutputProjects.add(
          rexBuilder.makeInputRef(join, nullIndicatorPos));

      final RelNode joinOutputProject = builder.push(join)
          .project(joinOutputProjects)
          .build();

      // nullIndicator is now at a different location in the output of
      // the join
      nullIndicatorPos = joinOutputProjExprCount - 1;

      final int groupCount = leftInputFieldCount;

      List<AggregateCall> newAggCalls = new ArrayList<>();
      k = -1;
      for (AggregateCall aggCall : aggCalls) {
        ++k;
        final List<Integer> argList;

        if (isCountStar.contains(k)) {
          // this is a count(*), transform it to count(nullIndicator)
          // the null indicator is located at the end
          argList = Collections.singletonList(nullIndicatorPos);
        } else {
          argList = new ArrayList<>();

          for (int aggArg : aggCall.getArgList()) {
            argList.add(aggArg + groupCount);
          }
        }

        int filterArg =
            aggCall.filterArg < 0 ? aggCall.filterArg
                : aggCall.filterArg + groupCount;
        newAggCalls.add(
            aggCall.adaptTo(joinOutputProject, argList, filterArg,
                aggregate.hasEmptyGroup(), groupCount == 0));
      }

      ImmutableBitSet groupSet =
          ImmutableBitSet.range(groupCount);
      builder.push(joinOutputProject)
          .aggregate(builder.groupKey(groupSet), newAggCalls);
      List<RexNode> newAggOutputProjectList = new ArrayList<>();
      for (int i : groupSet) {
        newAggOutputProjectList.add(
            rexBuilder.makeInputRef(builder.peek(), i));
      }

      RexNode newAggOutputProjects =
          d.removeCorrelationExpr(aggOutputProjects.get(0), false);
      newAggOutputProjectList.add(
          rexBuilder.makeCast(
              cluster.getTypeFactory().createTypeWithNullability(
                  newAggOutputProjects.getType(),
                  true),
              newAggOutputProjects));

      builder.project(newAggOutputProjectList);
      call.transformTo(builder.build());

      d.removeCorVarFromTree(correlate);
    }

    /** Rule configuration. */
    @Value.Immutable(singleton = false)
    public interface RemoveCorrelationForScalarAggregateRuleConfig extends RelRule.Config {
      @Override default RemoveCorrelationForScalarAggregateRule toRule() {
        return new RemoveCorrelationForScalarAggregateRule(this);
      }
    }
  }

  // REVIEW jhyde 29-Oct-2007: This rule is non-static, depends on the state
  // of members in RelDecorrelator, and has side-effects in the decorrelator.
  // This breaks the contract of a planner rule, and the rule will not be
  // reusable in other planners.

  // REVIEW jvs 29-Oct-2007:  Shouldn't it also be incorporating
  // the flavor attribute into the description?

  /**
   * Planner rule that adjusts projects when counts are added.
   *
   * <p>It matches a {@link Correlate} whose right input is either a
   * {@code Project} on top of an {@link Aggregate} (flavor {@code true}) or
   * an {@link Aggregate} directly (flavor {@code false}). In the second
   * case an identity Project is created on top of the Aggregate so that
   * both flavors share one code path.
   */
  public static final class AdjustProjectForCountAggregateRule
      extends RelRule<AdjustProjectForCountAggregateRule.AdjustProjectForCountAggregateRuleConfig> {

    // NOTE: "FAVLOR" is a misspelling of "FLAVOR"; the name is kept because
    // the constant is referenced from outside this class.
    /** Configuration that matches Correlate(any, Project(Aggregate)). */
    static final AdjustProjectForCountAggregateRuleConfig DEFAULT_WITH_FAVLOR =
        ImmutableAdjustProjectForCountAggregateRuleConfig.builder()
            .withOperandSupplier(b0 ->
                b0.operand(Correlate.class).inputs(
                    b1 -> b1.operand(RelNode.class).anyInputs(),
                    b2 -> b2.operand(Project.class)
                        .oneInput(b3 -> b3.operand(Aggregate.class).anyInputs())))
            .withFlavor(true)
            .build();

    // NOTE: "FAVLOR" is a misspelling of "FLAVOR"; see the constant above.
    /** Configuration that matches Correlate(any, Aggregate). */
    static final AdjustProjectForCountAggregateRuleConfig DEFAULT_WITHOUT_FAVLOR =
        ImmutableAdjustProjectForCountAggregateRuleConfig.builder()
            .withOperandSupplier(b0 ->
                b0.operand(Correlate.class).inputs(
                    b1 -> b1.operand(RelNode.class).anyInputs(),
                    b2 -> b2.operand(Aggregate.class).anyInputs()))
            .withFlavor(false)
            .build();

    /** Creates an AdjustProjectForCountAggregateRule. */
    AdjustProjectForCountAggregateRule(AdjustProjectForCountAggregateRuleConfig config) {
      super(config);
    }

    /**
     * Extracts the operands according to the configured flavor and
     * delegates to {@code onMatch2}.
     *
     * @param call Rule call
     */
    @Override public void onMatch(RelOptRuleCall call) {
      final RelDecorrelator d = call.getPlanner().getDecorrelator();
      final Correlate correlate = call.rel(0);
      final RelNode left = call.rel(1);
      final Project aggOutputProject;
      final Aggregate aggregate;
      if (config.flavor()) {
        aggOutputProject = call.rel(2);
        aggregate = call.rel(3);
      } else {
        aggregate = call.rel(2);

        // Create identity projection
        final PairList<RexNode, String> projects = PairList.of();
        final List<RelDataTypeField> fields =
            aggregate.getRowType().getFieldList();
        for (int i = 0; i < fields.size(); i++) {
          RexInputRef.add2(projects, i, fields);
        }
        final RelBuilder relBuilder = call.builder();
        relBuilder.push(aggregate)
            .projectNamed(projects.leftList(), projects.rightList(), true);
        aggOutputProject = (Project) relBuilder.build();
      }
      onMatch2(d, call, correlate, left, aggOutputProject, aggregate);
    }

    /**
     * Moves the aggregate output Project above a newly built Correlate, or
     * returns without transforming if any precondition fails.
     *
     * @param d Decorrelator whose state is read and updated
     * @param call Rule call to transform
     * @param correlate Matched Correlate
     * @param leftInput Left input of the Correlate
     * @param aggOutputProject Project on top of the Aggregate
     * @param aggregate Matched Aggregate
     */
    private static void onMatch2(
        RelDecorrelator d,
        RelOptRuleCall call,
        Correlate correlate,
        RelNode leftInput,
        Project aggOutputProject,
        Aggregate aggregate) {
      if (d.generatedCorRels.contains(correlate)) {
        // This Correlate was generated by a previous invocation of
        // this rule. No further work to do.
        return;
      }

      d.setCurrent(call.getPlanner().getRoot(), correlate);

      // check for this pattern
      // The pattern matching could be simplified if rules can be applied
      // during decorrelation,
      //
      // CorrelateRel(left correlation, condition = true)
      //   leftInput
      //   Project-A (a RexNode)
      //     Aggregate (groupby (0), agg0(), agg1()...)

      // check aggOutputProj projects only one expression
      List<RexNode> aggOutputProjExprs = aggOutputProject.getProjects();
      if (aggOutputProjExprs.size() != 1) {
        return;
      }

      // Correlate never carries a join condition, so only the join type
      // needs to be checked.
      if (correlate.getJoinType() != JoinRelType.LEFT) {
        return;
      }

      // check that the agg is on the entire input
      if (!aggregate.getGroupSet().isEmpty()) {
        return;
      }

      List<AggregateCall> aggCalls = aggregate.getAggCallList();
      Set<Integer> isCount = new HashSet<>();

      // remember the count() positions
      for (int i = 0; i < aggCalls.size(); i++) {
        if (aggCalls.get(i).getAggregation() instanceof SqlCountAggFunction) {
          isCount.add(i);
        }
      }

      // now rewrite the plan to
      //
      // Project-A' (all LHS plus transformed original projections,
      //             replacing references to count() with case statement)
      //   Correlate(left correlation, condition = true)
      //     leftInput
      //     Aggregate(groupby (0), agg0(), agg1()...)
      //
      final RexBuilder rexBuilder = d.relBuilder.getRexBuilder();
      List<RexNode> requiredNodes =
          correlate.getRequiredColumns().asList().stream()
              .map(ord -> rexBuilder.makeInputRef(correlate, ord))
              .collect(Collectors.toList());
      Correlate newCorrelate = (Correlate) d.relBuilder.push(leftInput)
          .push(aggregate).correlate(correlate.getJoinType(),
              correlate.getCorrelationId(),
              requiredNodes).build();


      // remember this rel so we don't fire rule on it again
      // REVIEW jhyde 29-Oct-2007: rules should not save state; rule
      // should recognize patterns where it does or does not need to do
      // work
      d.generatedCorRels.add(newCorrelate);

      // need to update the mapCorToCorRel Update the output position
      // for the corVars: only pass on the corVars that are not used in
      // the join key.
      if (d.cm.mapCorToCorRel.get(correlate.getCorrelationId()) == correlate) {
        d.cm.mapCorToCorRel.put(correlate.getCorrelationId(), newCorrelate);
      }

      RelNode newOutput =
          d.aggregateCorrelatorOutput(newCorrelate, aggOutputProject, isCount);

      call.transformTo(newOutput);
    }

    /** Rule configuration. */
    @Value.Immutable(singleton = false)
    public interface AdjustProjectForCountAggregateRuleConfig extends RelRule.Config {
      @Override default AdjustProjectForCountAggregateRule toRule() {
        return new AdjustProjectForCountAggregateRule(this);
      }

      /** Returns the flavor of the rule (true for 4 operands, false for 3
       * operands). */
      boolean flavor();

      /** Sets {@link #flavor}. */
      AdjustProjectForCountAggregateRuleConfig withFlavor(boolean flavor);
    }
  }

  /**
   * A unique reference to a correlation field.
   *
   * <p>For instance, if a RelNode references emp.name multiple times, it would
   * result in multiple {@code CorRef} objects that differ just in
   * {@link CorRef#uniqueKey}.
   *
   * <p>Equality, hash code and ordering are all based on the values of
   * {@link #corr}, {@link #field} and {@link #uniqueKey}, so that
   * {@link #compareTo} is consistent with {@link #equals}.
   */
  static class CorRef implements Comparable<CorRef> {
    /** Distinguishes several references to the same correlation field. */
    public final int uniqueKey;
    /** Correlation variable being referenced. */
    public final CorrelationId corr;
    /** Ordinal of the referenced field. */
    public final int field;

    /**
     * Creates a CorRef.
     *
     * @param corr Correlation variable
     * @param field Ordinal of the referenced field
     * @param uniqueKey Key that tells references to the same field apart
     */
    CorRef(CorrelationId corr, int field, int uniqueKey) {
      this.corr = corr;
      this.field = field;
      this.uniqueKey = uniqueKey;
    }

    /** Returns the correlation name and field ordinal, separated by a dot;
     * the unique key is not included. */
    @Override public String toString() {
      return corr.getName() + '.' + field;
    }

    @Override public int hashCode() {
      return Objects.hash(uniqueKey, corr, field);
    }

    @Override public boolean equals(@Nullable Object o) {
      // Compare the correlation id by value, not by reference, so that
      // equals agrees with hashCode and compareTo.
      return this == o
          || o instanceof CorRef
          && uniqueKey == ((CorRef) o).uniqueKey
          && corr.equals(((CorRef) o).corr)
          && field == ((CorRef) o).field;
    }

    /** Orders by correlation id, then field ordinal, then unique key. */
    @Override public int compareTo(CorRef o) {
      int c = corr.compareTo(o.corr);
      if (c != 0) {
        return c;
      }
      c = Integer.compare(field, o.field);
      if (c != 0) {
        return c;
      }
      return Integer.compare(uniqueKey, o.uniqueKey);
    }

    /** Returns the definition (correlation and field, without the unique
     * key) that this reference points to. */
    public CorDef def() {
      return new CorDef(corr, field);
    }
  }

  /**
   * A correlation and a field.
   *
   * <p>Equality, hash code and ordering are all based on the values of
   * {@link #corr} and {@link #field}, so that {@link #compareTo} is
   * consistent with {@link #equals}.
   */
  static class CorDef implements Comparable<CorDef> {
    /** Correlation variable. */
    public final CorrelationId corr;
    /** Ordinal of the field within the correlation variable. */
    public final int field;

    /**
     * Creates a CorDef.
     *
     * @param corr Correlation variable
     * @param field Ordinal of the field
     */
    CorDef(CorrelationId corr, int field) {
      this.corr = corr;
      this.field = field;
    }

    /** Returns the correlation name and field ordinal, separated by a dot. */
    @Override public String toString() {
      return corr.getName() + '.' + field;
    }

    @Override public int hashCode() {
      return Objects.hash(corr, field);
    }

    @Override public boolean equals(@Nullable Object o) {
      // Compare the correlation id by value, not by reference, so that
      // equals agrees with hashCode and compareTo.
      return this == o
          || o instanceof CorDef
          && corr.equals(((CorDef) o).corr)
          && field == ((CorDef) o).field;
    }

    /** Orders by correlation id, then field ordinal. */
    @Override public int compareTo(CorDef o) {
      int c = corr.compareTo(o.corr);
      if (c != 0) {
        return c;
      }
      return Integer.compare(field, o.field);
    }
  }

  /**
   * Index of where {@link org.apache.calcite.rel.core.Correlate} nodes and
   * the references to their variables sit in a tree of {@link RelNode}s.
   *
   * <p>It is used to drive the decorrelation process.
   * Treat it as immutable; rebuild if you modify the tree.
   *
   * <p>It holds three maps:
   *
   * <ol>
   * <li>{@link #mapRefRelToCorRef} maps a {@link RelNode} to the correlated
   * variables it references;
   *
   * <li>{@link #mapCorToCorRel} maps a correlated variable to the
   * {@link Correlate} providing it;
   *
   * <li>{@link #mapFieldAccessToCorRef} maps a rex field access to
   * the corVar it represents. Because typeFlattener does not clone or
   * modify a correlated field access this map does not need to be
   * updated.
   * </ol>
   */
  protected static class CorelMap {
    private final Multimap<RelNode, CorRef> mapRefRelToCorRef;
    private final NavigableMap<CorrelationId, RelNode> mapCorToCorRel;
    private final Map<RexFieldAccess, CorRef> mapFieldAccessToCorRef;

    // TODO: create immutable copies of all maps
    // (only the field-access map is copied at present)
    private CorelMap(Multimap<RelNode, CorRef> refRelToCorRef,
        NavigableMap<CorrelationId, RelNode> corToCorRel,
        Map<RexFieldAccess, CorRef> fieldAccessToCorRef) {
      this.mapFieldAccessToCorRef = ImmutableMap.copyOf(fieldAccessToCorRef);
      this.mapCorToCorRel = corToCorRel;
      this.mapRefRelToCorRef = refRelToCorRef;
    }

    /** Returns the three maps, one per line, each prefixed by its name. */
    @Override public String toString() {
      final StringBuilder buf = new StringBuilder();
      buf.append("mapRefRelToCorRef=").append(mapRefRelToCorRef);
      buf.append("\nmapCorToCorRel=").append(mapCorToCorRel);
      buf.append("\nmapFieldAccessToCorRef=").append(mapFieldAccessToCorRef);
      buf.append("\n");
      return buf.toString();
    }

    @SuppressWarnings("UndefinedEquals")
    @Override public boolean equals(@Nullable Object obj) {
      if (obj == this) {
        return true;
      }
      if (!(obj instanceof CorelMap)) {
        return false;
      }
      final CorelMap that = (CorelMap) obj;
      // TODO: Multimap does not have well-defined equals behavior
      if (!mapRefRelToCorRef.equals(that.mapRefRelToCorRef)) {
        return false;
      }
      if (!mapCorToCorRel.equals(that.mapCorToCorRel)) {
        return false;
      }
      return mapFieldAccessToCorRef.equals(that.mapFieldAccessToCorRef);
    }

    @Override public int hashCode() {
      return Objects.hash(mapRefRelToCorRef, mapCorToCorRel,
          mapFieldAccessToCorRef);
    }

    /** Creates a CorelMap with given contents. */
    public static CorelMap of(
        SortedSetMultimap<RelNode, CorRef> mapRefRelToCorVar,
        NavigableMap<CorrelationId, RelNode> mapCorToCorRel,
        Map<RexFieldAccess, CorRef> mapFieldAccessToCorVar) {
      final CorelMap corelMap =
          new CorelMap(mapRefRelToCorVar, mapCorToCorRel,
              mapFieldAccessToCorVar);
      return corelMap;
    }

    /** Returns the map from each correlation id to the relational
     * expression that provides it. */
    public NavigableMap<CorrelationId, RelNode> getMapCorToCorRel() {
      return mapCorToCorRel;
    }

    /**
     * Returns whether there are any correlating variables in this statement.
     *
     * @return whether there are any correlating variables
     */
    public boolean hasCorrelation() {
      final boolean noCorrelates = mapCorToCorRel.isEmpty();
      return !noCorrelates;
    }
  }

  /** Builds a {@link org.apache.calcite.sql2rel.RelDecorrelator.CorelMap}
   * by walking one or more relational expression trees and recording every
   * field access made through a correlation variable. */
  public static class CorelMapBuilder extends RelHomogeneousShuttle {
    final NavigableMap<CorrelationId, RelNode> mapCorToCorRel =
        new TreeMap<>();

    final SortedSetMultimap<RelNode, CorRef> mapRefRelToCorRef =
        MultimapBuilder.SortedSetMultimapBuilder.hashKeys()
            .treeSetValues()
            .build();

    final Map<RexFieldAccess, CorRef> mapFieldAccessToCorVar = new HashMap<>();

    final Holder<Integer> offset = Holder.of(0);
    int corrIdGenerator = 0;

    /** Creates a CorelMap by iterating over a {@link RelNode} tree.
     *
     * @param rels roots of the trees to scan; each is passed through
     *     {@code stripHep} before being visited
     * @return a CorelMap over the maps accumulated by this builder
     */
    public CorelMap build(RelNode... rels) {
      for (RelNode root : rels) {
        final RelNode stripped = stripHep(root);
        stripped.accept(this);
      }
      return CorelMap.of(mapRefRelToCorRef, mapCorToCorRel,
          mapFieldAccessToCorVar);
    }

    /** Scans the expressions of joins, filters and projects for correlated
     * field accesses, registers correlates by their correlation id, and
     * then continues into the inputs. */
    @Override public RelNode visit(RelNode other) {
      if (other instanceof Join) {
        final Join join = (Join) other;
        scanExpressions(join, Collections.singleton(join.getCondition()));
        // Inputs are walked by visitJoin, which returns the join itself.
        return visitJoin(join);
      }
      if (other instanceof Correlate) {
        final Correlate correlate = (Correlate) other;
        mapCorToCorRel.put(correlate.getCorrelationId(), correlate);
        return visitJoin(correlate);
      }
      if (other instanceof Filter) {
        final Filter filter = (Filter) other;
        scanExpressions(filter, Collections.singleton(filter.getCondition()));
      } else if (other instanceof Project) {
        final Project project = (Project) other;
        scanExpressions(project, project.getProjects());
      }
      return super.visit(other);
    }

    /** Visits each expression of {@code rel} with a correlation-collecting
     * visitor, keeping {@code rel} on the shuttle's stack for the duration. */
    private void scanExpressions(RelNode rel,
        Iterable<? extends RexNode> exprs) {
      try {
        stack.push(rel);
        final RexVisitorImpl<Void> collector = rexVisitor(rel);
        for (RexNode expr : exprs) {
          expr.accept(collector);
        }
      } finally {
        stack.pop();
      }
    }

    /** Visits a child after applying {@code stripHep} to it. */
    @Override protected RelNode visitChild(RelNode parent, int i,
        RelNode input) {
      final RelNode stripped = stripHep(input);
      return super.visitChild(parent, i, stripped);
    }

    /** Visits both inputs of a two-input relational expression. While the
     * right input is visited, {@link #offset} is advanced by the left
     * input's field count; it is restored to its prior value afterwards. */
    private RelNode visitJoin(BiRel join) {
      final int savedOffset = offset.get();
      final RelNode left = join.getLeft();
      visitChild(join, 0, left);
      final int leftFieldCount = join.getLeft().getRowType().getFieldCount();
      offset.set(savedOffset + leftFieldCount);
      visitChild(join, 1, join.getRight());
      offset.set(savedOffset);
      return join;
    }

    /** Creates a deep expression visitor that records, on behalf of
     * {@code rel}, every field access whose reference expression is a
     * correlation variable, and descends into sub-queries. */
    private RexVisitorImpl<Void> rexVisitor(final RelNode rel) {
      return new RexVisitorImpl<Void>(true) {
        @Override public Void visitFieldAccess(RexFieldAccess fieldAccess) {
          final RexNode target = fieldAccess.getReferenceExpr();
          if (target instanceof RexCorrelVariable) {
            // When different rel nodes refer to the same correlation var
            // (e.g. in case of NOT IN) the existing CorRef is reused rather
            // than generating another one; either way 'rel' is recorded as
            // using it.
            CorRef corRef = mapFieldAccessToCorVar.get(fieldAccess);
            if (corRef == null) {
              final RexCorrelVariable corVar = (RexCorrelVariable) target;
              corRef =
                  new CorRef(corVar.id, fieldAccess.getField().getIndex(),
                      corrIdGenerator++);
              mapFieldAccessToCorVar.put(fieldAccess, corRef);
            }
            mapRefRelToCorRef.put(rel, corRef);
          }
          return super.visitFieldAccess(fieldAccess);
        }

        @Override public Void visitSubQuery(RexSubQuery subQuery) {
          // Walk the sub-query's plan with the enclosing builder first.
          final RelNode subPlan = subQuery.rel;
          subPlan.accept(CorelMapBuilder.this);
          return super.visitSubQuery(subQuery);
        }
      };
    }
  }

  /** Frame describing the relational expression after decorrelation
   * and where to find the output fields and correlation variables
   * among its output fields.
   *
   * <p>Both maps are copied into immutable sorted maps on construction. */
  static class Frame {
    /** The relational expression after decorrelation; never null. */
    final RelNode r;
    /** For each correlation definition, the ordinal of the field of
     * {@link #r} that carries it. */
    final ImmutableSortedMap<CorDef, Integer> corDefOutputs;
    /** For each output ordinal of the old relational expression, the
     * ordinal of the corresponding field of {@link #r}. */
    final ImmutableSortedMap<Integer, Integer> oldToNewOutputs;

    /**
     * Creates a Frame.
     *
     * @param oldRel relational expression before decorrelation; used only
     *     to validate {@code oldToNewOutputs} against its row type
     * @param r relational expression after decorrelation, must not be null
     * @param corDefOutputs positions of correlation definitions in {@code r}
     * @param oldToNewOutputs mapping from field ordinals of {@code oldRel}
     *     to field ordinals of {@code r}
     */
    Frame(RelNode oldRel, RelNode r, NavigableMap<CorDef, Integer> corDefOutputs,
        Map<Integer, Integer> oldToNewOutputs) {
      this.r = requireNonNull(r, "r");
      this.corDefOutputs = ImmutableSortedMap.copyOf(corDefOutputs);
      this.oldToNewOutputs = ImmutableSortedMap.copyOf(oldToNewOutputs);
      // The checks below run only when assertions are enabled.
      // Every referenced ordinal must exist in the row type it points into.
      assert allLessThan(this.corDefOutputs.values(),
          r.getRowType().getFieldCount(), Litmus.THROW);
      assert allLessThan(this.oldToNewOutputs.keySet(),
          oldRel.getRowType().getFieldCount(), Litmus.THROW);
      assert allLessThan(this.oldToNewOutputs.values(),
          r.getRowType().getFieldCount(), Litmus.THROW);
      RelDataType rowType = oldRel.getRowType();
      // Keys are distinct and each is below the old field count, so having
      // at least that many entries means every old field has a mapping.
      assert this.oldToNewOutputs.size() >= rowType.getFieldCount();
    }
  }

  /**
   * Checks whether the field at the given index is known to be non-null.
   *
   * <p>The field qualifies if its declared type is not nullable or, failing
   * that, if a structural walk of {@code rel} can prove it. The check is
   * conservative: {@code true} guarantees the field contains no {@code null}
   * values, whereas {@code false} only means that no proof was found, not
   * that {@code null} values are present.
   *
   * @param rel relational expression whose output is inspected
   * @param index ordinal of the field in the row type of {@code rel}
   * @return {@code true} only if the field is proven to be non-null
   */
  private static boolean isFieldNotNull(RelNode rel, int index) {
    final RelDataType fieldType =
        rel.getRowType().getFieldList().get(index).getType();
    if (!fieldType.isNullable()) {
      // The declared type already rules out nulls.
      return true;
    }
    return isFieldNotNullRecursive(rel, index);
  }

  /**
   * Tries to prove, from the structure of the plan alone, that output field
   * {@code index} of {@code rel} never contains {@code null}.
   *
   * <p>Only {@link Project}, {@link Aggregate}, {@link Filter} and
   * {@link Join} are analyzed; any other relational expression yields
   * {@code false}. A {@code false} result means "not proven", not "nullable".
   *
   * @param rel relational expression to analyze
   * @param index ordinal of the field in the output of {@code rel}
   * @return {@code true} only if the field is proven to be non-null
   */
  private static boolean isFieldNotNullRecursive(RelNode rel, int index) {
    if (rel instanceof Project) {
      Project project = (Project) rel;

      // Only a plain column reference can be traced further down.
      RexNode expr = project.getProjects().get(index);
      if (!(expr instanceof RexInputRef)) {
        return false;
      }
      return isFieldNotNullRecursive(project.getInput(), ((RexInputRef) expr).getIndex());
    } else if (rel instanceof Aggregate) {
      Aggregate agg = (Aggregate) rel;
      ImmutableBitSet groupSet = agg.getGroupSet();

      // Only the leading group-key fields map back to input columns. The
      // number of keys is cardinality(); size() is the bit capacity of the
      // set, so comparing against it let aggregate-call columns through and
      // made the asList().get(index) lookup below fail.
      if (index >= groupSet.cardinality()) {
        return false;
      }
      // NOTE(review): with multiple grouping sets (ROLLUP/CUBE) a group key
      // can presumably be null in the output even if the input column is
      // not; confirm whether that case needs to be excluded here.
      return isFieldNotNullRecursive(agg.getInput(), groupSet.asList().get(index));
    } else if (rel instanceof Filter) {
      Filter filter = (Filter) rel;
      // If the condition cannot be true when the field is null, rows with a
      // null in that field do not pass the filter.
      if (Strong.isNotTrue(filter.getCondition(), ImmutableBitSet.of(index))) {
        return true;
      }
      return isFieldNotNullRecursive(filter.getInput(), index);
    } else if (rel instanceof Join) {
      Join join = (Join) rel;
      int leftFieldCnt = join.getLeft().getRowType().getFieldCount();
      // NOTE(review): the condition-based shortcut below is applied to every
      // join type that does not generate nulls on the field's side. For join
      // types that keep rows failing the condition (e.g. the left side of a
      // LEFT or ANTI join) that reasoning looks unsound; verify.
      if (index < leftFieldCnt) {
        if (!join.getJoinType().generatesNullsOnLeft()) {
          return Strong.isNotTrue(join.getCondition(), ImmutableBitSet.of(index))
              || isFieldNotNullRecursive(join.getLeft(), index);
        }
      } else {
        if (!join.getJoinType().generatesNullsOnRight()) {
          // Right-side ordinals are shifted by the left field count.
          return Strong.isNotTrue(join.getCondition(), ImmutableBitSet.of(index))
              || isFieldNotNullRecursive(join.getRight(), index - leftFieldCnt);
        }
      }
      return false;
    } else {
      return false;
    }
  }

  // -------------------------------------------------------------------------
  //  Getter/Setter
  // -------------------------------------------------------------------------

  /**
   * Returns the {@code visitor} on which the {@code MethodDispatcher} dispatches
   * each {@code decorrelateRel} method.
   *
   * <p>The default implementation returns this instance. A sub-class may
   * override this method to supply a different instance as the
   * {@code visitor}.
   *
   * @return the visitor that receives the dispatched calls
   */
  protected RelDecorrelator getVisitor() {
    return this;
  }

  /** Returns the rules applied on the rel after decorrelation, never null.
   *
   * <p>The default implementation returns an empty list. */
  protected Collection<RelOptRule> getPostDecorrelateRules() {
    return Collections.emptyList();
  }
}
